BeginnerDSA · Lesson 3

Binary Search

Master three binary search patterns: classic search, search in rotated sorted arrays, and binary search on the answer.

Classic Binary Search

Binary search finds a target in a sorted array in O(log n).

def binary_search(arr, target):
    """Return an index of ``target`` in the ascending list ``arr``, or -1.

    The closed window [lo, hi] holds every position where ``target`` could
    still be. Each probe discards half of it, so at most O(log n)
    comparisons are made.
    """
    lo = 0
    hi = len(arr) - 1
    while lo <= hi:
        # Python ints are arbitrary precision, so (lo + hi) cannot overflow.
        # The ``lo + (hi - lo) // 2`` trick matters only in fixed-width
        # languages; for non-negative lo <= hi both forms give the same mid.
        mid = (lo + hi) // 2
        probe = arr[mid]
        if probe == target:
            return mid
        if probe < target:
            lo = mid + 1  # target can only be to the right of mid
        else:
            hi = mid - 1  # target can only be to the left of mid
    return -1

print(binary_search([1, 3, 5, 7, 9, 11], 7))   # 3
print(binary_search([1, 3, 5, 7, 9, 11], 6))   # -1

Find First and Last Position

def search_range(nums, target):
    """Return ``[first, last]`` indices of ``target`` in ascending ``nums``.

    Both entries are -1 when ``target`` is absent. Two binary searches are
    run. On a hit, each one records the index and keeps narrowing toward
    one side: leftward for the first occurrence, rightward for the last.
    """
    def locate(prefer_left):
        lo, hi = 0, len(nums) - 1
        found = -1
        while lo <= hi:
            mid = (lo + hi) // 2
            value = nums[mid]
            if value == target:
                found = mid
                # Keep looking past this hit toward the requested end.
                if prefer_left:
                    hi = mid - 1
                else:
                    lo = mid + 1
            elif value < target:
                lo = mid + 1
            else:
                hi = mid - 1
        return found

    first = locate(prefer_left=True)
    last = locate(prefer_left=False)
    return [first, last]

print(search_range([5,7,7,8,8,10], 8))  # [3, 4]
print(search_range([5,7,7,8,8,10], 6))  # [-1, -1]

Search in Rotated Sorted Array

def search_rotated(nums, target):
    """Return an index of ``target`` in a rotated ascending list, or -1.

    ``nums`` is an ascending list that may have been rotated at an unknown
    pivot, e.g. [4, 5, 6, 7, 0, 1, 2]. In the window [left, right], at least
    one of the halves [left, mid] and [mid, right] is sorted. The search
    keeps the sorted half if ``target`` lies within its range and otherwise
    moves to the other half.

    Duplicate values are supported. When nums[left], nums[mid] and
    nums[right] are all equal, it cannot be told which half is sorted
    (compare [1, 0, 1, 1, 1] with [1, 1, 1, 0, 1]), so both ends are
    trimmed by one. This keeps the answer correct but degrades the worst
    case from O(log n) to O(n) for inputs with many repeats. With distinct
    values that branch only fires when left == mid == right and does not
    change the result.
    """
    left, right = 0, len(nums) - 1
    while left <= right:
        mid = (left + right) // 2
        if nums[mid] == target:
            return mid
        if nums[left] == nums[mid] == nums[right]:
            # Both ends equal nums[mid], which is not the target, so
            # discarding them cannot lose a match.
            left += 1
            right -= 1
        elif nums[left] <= nums[mid]:
            # Left half [left, mid] is sorted
            if nums[left] <= target < nums[mid]:
                right = mid - 1
            else:
                left = mid + 1
        else:
            # Right half [mid, right] is sorted
            if nums[mid] < target <= nums[right]:
                left = mid + 1
            else:
                right = mid - 1
    return -1

print(search_rotated([4,5,6,7,0,1,2], 0))  # 4
print(search_rotated([4,5,6,7,0,1,2], 3))  # -1
print(search_rotated([1,0,1,1,1], 0))      # 1 (duplicates)

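Binary Search on the Answer

When the answer is an integer in a known range, and "is this candidate good enough?" flips from false to true exactly once as the candidate grows, you can binary-search the range of candidate answers instead of an array.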

def sqrt_integer(x):
    """Find floor(sqrt(x)) without math library.

    For x >= 2 the answer lies in [1, x // 2]. That range is binary-searched
    for the largest integer ``r`` with ``r * r <= x``.

    Raises:
        ValueError: if ``x`` is negative, since its square root is not real.
    """
    if x < 0:
        raise ValueError("sqrt_integer() argument must be non-negative")
    if x < 2:
        return x  # floor(sqrt(0)) == 0 and floor(sqrt(1)) == 1
    left, right = 1, x // 2
    while left <= right:
        mid = (left + right) // 2
        square = mid * mid
        if square == x:
            return mid
        if square < x:
            left = mid + 1
        else:
            right = mid - 1
    # On exit left == right + 1. Every value above right squares to more
    # than x, and every value below left squares to less, so right is the floor.
    return right

print(sqrt_integer(8))   # 2
print(sqrt_integer(16))  # 4

def koko_eating(piles, h):
    """Minimum eating speed to finish all bananas in h hours.

    At speed ``speed``, pile ``p`` takes ceil(p / speed) hours. Total time
    never increases as speed grows, so the smallest speed in
    [1, max(piles)] whose total is <= ``h`` is found by binary search.

    Raises:
        ValueError: if ``h < len(piles)``. Every pile takes at least one
            hour, so no speed can finish and any returned value would be
            infeasible. Also raised (by ``max``) if ``piles`` is empty.
    """
    if h < len(piles):
        raise ValueError(
            f"cannot finish {len(piles)} piles in {h} hours at any speed"
        )

    def can_finish(speed):
        # (p + speed - 1) // speed is the integer ceiling of p / speed
        return sum((p + speed - 1) // speed for p in piles) <= h

    left, right = 1, max(piles)
    while left < right:
        mid = (left + right) // 2
        if can_finish(mid):
            right = mid  # mid works; try slower
        else:
            left = mid + 1  # mid too slow; need faster
    return left

print(koko_eating([3,6,7,11], 8))   # 4
print(koko_eating([30,11,23,4,20], 5))  # 30

Exercises

Exercise 1: Find Minimum in Rotated Sorted Array

Solution:

def find_min_rotated(nums):
    """Return the minimum of a rotated ascending list.

    Each step compares nums[mid] with the right end of the window:
    - nums[mid] > nums[right]: the rotation point, and so the minimum, is
      strictly right of mid.
    - nums[mid] < nums[right]: [mid, right] is sorted, so the minimum is at
      mid or to its left.
    - equal (possible only with duplicates): the side holding the minimum
      is unknown. nums[right] has an equal copy at mid, so dropping it
      loses nothing. This makes the worst case O(n) for heavily repeated
      input.

    An empty list raises IndexError.
    """
    left, right = 0, len(nums) - 1
    while left < right:
        mid = (left + right) // 2
        if nums[mid] > nums[right]:
            left = mid + 1
        elif nums[mid] < nums[right]:
            right = mid
        else:
            right -= 1
    return nums[left]

print(find_min_rotated([3,4,5,1,2]))     # 1
print(find_min_rotated([4,5,6,7,0,1,2])) # 0
print(find_min_rotated([1,1,1,0,1]))     # 0 (duplicates)

Exercise 2: Peak Element

Find the index of any peak element: one strictly greater than its neighbors, where positions outside the array count as −∞.

Solution:

def find_peak(nums):
    """Return the index of some peak in ``nums``.

    The search assumes adjacent elements differ. If nums[mid] is greater
    than its right neighbour, the values descend there, so the window keeps
    mid and everything to its left. Otherwise the window moves right of mid.
    The window shrinks until one index remains.

    An empty list returns 0, which is not a valid index, so pass a
    non-empty list.
    """
    lo = 0
    hi = len(nums) - 1
    while lo < hi:
        mid = (lo + hi) // 2
        descending = nums[mid] > nums[mid + 1]
        if descending:
            hi = mid
        else:
            lo = mid + 1
    return lo

print(find_peak([1, 2, 3, 1]))   # 2 (value=3)
print(find_peak([1, 2, 1, 3, 5, 6, 4]))  # 5 (value=6)