Skip to content

cp-snippets

Arrays

DifferenceArray

Helper class for range updates using the difference-array technique.

Source code in cp_snippets/arrays.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
class DifferenceArray:
    """
    Range-add helper built on the difference-array technique.

    ``update`` adds a value to a contiguous slice in O(1); ``get_array``
    rebuilds the current values with a single O(n) running-sum pass.
    """
    def __init__(self, arr: List[int]):
        self.n = len(arr)
        # One spare trailing slot: an update ending at index n-1 writes its
        # cancelling delta to diff[n] without any special casing.
        self.diff = [0] * (self.n + 1)
        previous = 0
        for idx, value in enumerate(arr):
            # diff[0] == arr[0]; each later slot is the step from its left neighbour.
            self.diff[idx] = value - previous
            previous = value

    def update(self, l: int, r: int, val: int):
        """
        Adds val to arr[l...r] (inclusive).

        Raises:
            ValueError: if the array is empty.
            IndexError: unless 0 <= l <= r < n.
        """
        if not self.n:
            raise ValueError("Cannot update an empty DifferenceArray")
        in_bounds = 0 <= l <= r < self.n
        if not in_bounds:
            raise IndexError("update range out of bounds")
        self.diff[l] += val
        self.diff[r + 1] -= val

    def get_array(self) -> List[int]:
        """
        Reconstructs and returns the modified array as a new list.
        """
        rebuilt: List[int] = []
        running = 0
        for delta in self.diff[: self.n]:
            running += delta
            rebuilt.append(running)
        return rebuilt

get_array()

Reconstructs and returns the modified array.

Source code in cp_snippets/arrays.py
100
101
102
103
104
105
106
107
108
109
110
def get_array(self) -> List[int]:
    """
    Reconstructs and returns the modified array.

    Element i is the running sum diff[0] + ... + diff[i]; the spare
    trailing slot diff[n] is never read.
    """
    values: List[int] = []
    total = 0
    for step in self.diff[: self.n]:
        total += step
        values.append(total)
    return values

update(l, r, val)

Adds val to arr[l...r] (inclusive).

Source code in cp_snippets/arrays.py
89
90
91
92
93
94
95
96
97
98
def update(self, l: int, r: int, val: int):
    """
    Adds val to arr[l...r] (inclusive) in O(1).

    Only two difference cells change: diff[l] gains val and diff[r + 1]
    loses it, so the effect stops right after index r.

    Raises:
        ValueError: if the array is empty.
        IndexError: unless 0 <= l <= r < n.
    """
    if not self.n:
        raise ValueError("Cannot update an empty DifferenceArray")
    if l < 0 or l > r or r >= self.n:
        raise IndexError("update range out of bounds")
    self.diff[l] += val
    self.diff[r + 1] -= val

binary_search(arr, target, *, key=None)

Binary search in a sorted array.

Parameters:

Name Type Description Default
arr Sequence[T]

Sorted sequence.

required
target T

Value to search.

required
key Optional[Callable[[T], T]]

Optional transform applied to each element before it is compared with target (target itself is not transformed).

None

Returns:

Type Description
int

Index of target if found, otherwise -1.

Source code in cp_snippets/arrays.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def binary_search(arr: Sequence[T], target: T, *, key: Optional[Callable[[T], T]] = None) -> int:
    """Binary search in a sorted array.

    Args:
        arr: Sorted sequence (sorted by ``key`` when ``key`` is given).
        target: Value to search for; compared against ``key(element)`` when
            ``key`` is given, otherwise against the element itself.
        key: Optional transform applied to elements before comparison.

    Returns:
        Index of `target` if found, otherwise -1. With duplicates, the first
        matching index the search happens to probe is returned, which is not
        necessarily the leftmost one.
    """
    if key is None:
        def probe(i: int):
            return arr[i]
    else:
        def probe(i: int):
            return key(arr[i])

    # Closed interval [low, high]; shrinks by at least one slot per step.
    low, high = 0, len(arr) - 1
    while low <= high:
        middle = (low + high) // 2
        candidate = probe(middle)
        if candidate == target:
            return middle
        if candidate < target:
            low = middle + 1
        else:
            high = middle - 1
    return -1

get_prefix_sum(arr)

Computes the prefix sum array P where P[i] is the sum of arr[0...i-1]. P[0] = 0.

Source code in cp_snippets/arrays.py
55
56
57
58
59
60
61
62
63
64
def get_prefix_sum(arr: List[int]) -> List[int]:
    """
    Computes the prefix sum array P where P[i] is the sum of arr[0...i-1].
    P[0] = 0, so the result always has len(arr) + 1 entries.
    """
    sums = [0]
    for value in arr:
        sums.append(sums[-1] + value)
    return sums

lower_bound(arr, target)

Returns the first index i such that arr[i] >= target in a sorted array.

Source code in cp_snippets/arrays.py
46
47
48
def lower_bound(arr: Sequence[T], target: T) -> int:
    """Returns the first index i such that arr[i] >= target in a sorted array.

    Returns len(arr) when every element is smaller than target.
    """
    index = bisect_left(arr, target, 0, len(arr))
    return index

max_subarray_sum(arr)

Finds the maximum subarray sum using Kadane's Algorithm. Returns 0 if the array is empty. For arrays with all negative numbers, returns the max single element.

Source code in cp_snippets/arrays.py
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
def max_subarray_sum(arr: List[int]) -> int:
    """
    Finds the maximum subarray sum using Kadane's Algorithm.
    Returns 0 if the array is empty. For arrays with all negative numbers, returns the max single element.

    Only non-empty subarrays are considered, so for a non-empty input the
    answer is never smaller than its largest element.
    """
    if len(arr) == 0:
        return 0

    best = ending_here = arr[0]
    for value in arr[1:]:
        # Either extend the run ending at the previous element or start afresh here.
        ending_here = max(value, ending_here + value)
        if ending_here > best:
            best = ending_here
    return best

prefix_sum_2d(grid)

Builds 2D prefix sums.

Returns ps of shape (n+1) x (m+1) where: ps[i][j] = sum of grid[0..i-1][0..j-1]

Source code in cp_snippets/arrays.py
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
def prefix_sum_2d(grid: List[List[int]]) -> List[List[int]]:
    """Builds 2D prefix sums.

    Returns `ps` of shape (n+1) x (m+1) where:
    ps[i][j] = sum of grid[0..i-1][0..j-1]

    n is len(grid) and m is len(grid[0]). Row 0 and column 0 are all zeros.
    An empty grid yields [[0]]. A grid whose rows are empty yields n + 1 rows
    of [0], so the documented shape holds in that case too.
    """
    if not grid:
        return [[0]]

    n, m = len(grid), len(grid[0])
    ps = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n):
        above = ps[i]
        current = ps[i + 1]
        # row_acc is the sum of grid[i][0..j]; adding the block above
        # (rows 0..i-1, columns 0..j) gives the full rectangle.
        row_acc = 0
        for j in range(m):
            row_acc += grid[i][j]
            current[j + 1] = above[j + 1] + row_acc
    return ps

query_range_sum(prefix_sum, l, r)

Returns the sum of arr[l...r] (inclusive) using the prefix sum array.

Source code in cp_snippets/arrays.py
67
68
69
70
71
def query_range_sum(prefix_sum: List[int], l: int, r: int) -> int:
    """
    Returns the sum of arr[l...r] (inclusive) using the prefix sum array.

    ``prefix_sum`` uses the get_prefix_sum layout: P[0] = 0 and
    len(P) == len(arr) + 1. An empty range (l == r + 1) yields 0.

    Raises:
        IndexError: unless 0 <= l <= r + 1 <= len(arr). Without this check,
            Python's negative indexing would silently produce a wrong sum.
    """
    size = len(prefix_sum) - 1
    if not (0 <= l <= r + 1 <= size):
        raise IndexError("query range out of bounds")
    return prefix_sum[r + 1] - prefix_sum[l]

query_range_sum_2d(ps, r1, c1, r2, c2)

Rectangle sum query on a 2D prefix sum array.

Parameters:

Name Type Description Default
ps List[List[int]]

2D prefix sum array from prefix_sum_2d.

required
r1, c1

Top-left (inclusive).

required
r2, c2

Bottom-right (inclusive).

required
Source code in cp_snippets/arrays.py
132
133
134
135
136
137
138
139
140
def query_range_sum_2d(ps: List[List[int]], r1: int, c1: int, r2: int, c2: int) -> int:
    """Rectangle sum query on a 2D prefix sum array.

    Args:
        ps: 2D prefix sum array from `prefix_sum_2d`.
        r1, c1: Top-left (inclusive).
        r2, c2: Bottom-right (inclusive).

    Returns:
        Sum of grid[r1..r2][c1..c2]. An empty rectangle (r1 == r2 + 1 or
        c1 == c2 + 1) yields 0.

    Raises:
        IndexError: if the rectangle does not lie inside the grid. Without
            this check, negative indices would wrap around and silently give
            a wrong sum.
    """
    rows = len(ps) - 1
    cols = len(ps[0]) - 1 if ps else -1
    if not (0 <= r1 <= r2 + 1 <= rows and 0 <= c1 <= c2 + 1 <= cols):
        raise IndexError("query range out of bounds")
    # Inclusion-exclusion: full block minus the strips above and to the left,
    # plus the top-left corner that was subtracted twice.
    return ps[r2 + 1][c2 + 1] - ps[r1][c2 + 1] - ps[r2 + 1][c1] + ps[r1][c1]

sliding_window_max(arr, k)

Calculates the maximum of every window of size k using a deque.

Source code in cp_snippets/arrays.py
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
def sliding_window_max(arr: List[int], k: int) -> List[int]:
    """
    Calculates the maximum of every window of size k using a deque.

    Returns len(arr) - k + 1 values, or [] when arr is empty or k is not in
    1..len(arr). The deque stores indices whose values never increase from
    front to back, so its front is always the current window's maximum.
    """
    if not arr or k <= 0 or k > len(arr):
        return []

    candidates = deque()
    maxima = []
    for right, value in enumerate(arr):
        window_start = right - k + 1
        # Drop indices that have slid off the left edge.
        while candidates and candidates[0] < window_start:
            candidates.popleft()
        # Anything strictly smaller than the newcomer can never be a maximum again.
        while candidates and arr[candidates[-1]] < value:
            candidates.pop()
        candidates.append(right)
        if window_start >= 0:
            maxima.append(arr[candidates[0]])
    return maxima

sliding_window_sum(arr, k)

Calculates the sum of every window of size k.

Source code in cp_snippets/arrays.py
143
144
145
146
147
148
149
150
151
152
153
154
155
def sliding_window_sum(arr: List[int], k: int) -> List[int]:
    """
    Calculates the sum of every window of size k.

    Returns len(arr) - k + 1 sums, or [] when arr is empty or k is not in
    1..len(arr). Each window after the first is derived in O(1) from the
    previous one.
    """
    if not arr or k <= 0 or k > len(arr):
        return []

    running = sum(arr[:k])
    sums = [running]
    # Pair each element leaving the window with the one entering it.
    for outgoing, incoming in zip(arr, arr[k:]):
        running = running - outgoing + incoming
        sums.append(running)
    return sums

two_sum_sorted(arr, target)

Finds two indices (i, j) in a sorted array such that arr[i] + arr[j] == target. Returns (i, j) 0-indexed if found, otherwise None.

Source code in cp_snippets/arrays.py
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
def two_sum_sorted(arr: List[int], target: int) -> Optional[Tuple[int, int]]:
    """
    Finds two indices (i, j) in a sorted array such that arr[i] + arr[j] == target.
    Returns (i, j) 0-indexed if found, otherwise None.

    Two pointers start at both ends and move inwards; i < j always holds, so
    an element is never paired with itself. Runs in O(n).
    """
    lo = 0
    hi = len(arr) - 1
    while lo < hi:
        pair = arr[lo] + arr[hi]
        if pair == target:
            return lo, hi
        if pair < target:
            lo += 1  # need a bigger sum
        else:
            hi -= 1  # need a smaller sum
    return None

upper_bound(arr, target)

Returns the first index i such that arr[i] > target in a sorted array.

Source code in cp_snippets/arrays.py
51
52
53
def upper_bound(arr: Sequence[T], target: T) -> int:
    """Returns the first index i such that arr[i] > target in a sorted array.

    Returns len(arr) when no element exceeds target.
    """
    index = bisect_right(arr, target, 0, len(arr))
    return index

Bit Utils

check_bit(n, i)

Checks if the ith bit of n is set.

Source code in cp_snippets/bit_utils.py
16
17
18
def check_bit(n: int, i: int) -> bool:
    """Checks if the ith bit of n is set (bit 0 is the least significant)."""
    return (n >> i) & 1 == 1

clear_bit(n, i)

Sets the ith bit of n to 0.

Source code in cp_snippets/bit_utils.py
6
7
8
def clear_bit(n: int, i: int) -> int:
    """Sets the ith bit of n to 0, leaving every other bit unchanged."""
    keep_mask = ~(1 << i)
    return n & keep_mask

count_set_bits(n)

Counts the number of set bits (1s) in the binary representation of n.

Note: Uses int.bit_count() on Python 3.10+ and falls back to bin(n).count('1') on older versions.

Source code in cp_snippets/bit_utils.py
21
22
23
24
25
26
27
28
29
def count_set_bits(n: int) -> int:
    """Counts the number of set bits (1s) in the binary representation of n.

    Uses int.bit_count() where available (Python 3.10+) and otherwise falls
    back to counting '1' characters in bin(n). Both paths count the bits of
    abs(n) for negative n.
    """
    native = getattr(n, "bit_count", None)
    if native is not None:
        return native()
    return bin(n).count('1')

is_power_of_two(n)

Checks if n is a power of two.

Source code in cp_snippets/bit_utils.py
37
38
39
def is_power_of_two(n: int) -> bool:
    """Checks if n is a power of two (1, 2, 4, ...); zero and negatives are not."""
    # A positive power of two has one set bit, and n & (n - 1) clears the lowest set bit.
    return n > 0 and not (n & (n - 1))

lowest_set_bit(n)

Returns the value of the lowest set bit in n (e.g., for 12 (1100), returns 4 (0100)).

Source code in cp_snippets/bit_utils.py
32
33
34
def lowest_set_bit(n: int) -> int:
    """Returns the value of the lowest set bit in n (e.g., for 12 (1100), returns 4 (0100)).

    Returns 0 when n == 0.
    """
    twos_complement = ~n + 1  # equal to -n
    return n & twos_complement

set_bit(n, i)

Sets the ith bit of n to 1.

Source code in cp_snippets/bit_utils.py
1
2
3
def set_bit(n: int, i: int) -> int:
    """Sets the ith bit of n to 1, leaving every other bit unchanged."""
    bit = 1 << i
    return n | bit

toggle_bit(n, i)

Toggles the ith bit of n.

Source code in cp_snippets/bit_utils.py
11
12
13
def toggle_bit(n: int, i: int) -> int:
    """Toggles the ith bit of n (1 becomes 0 and 0 becomes 1)."""
    bit = 1 << i
    return n ^ bit

DP

knapsack_01(weights, values, capacity)

0/1 knapsack maximum value.

Parameters:

Name Type Description Default
weights Sequence[int]

item weights

required
values Sequence[int]

item values

required
capacity int

knapsack capacity

required

Returns:

Type Description
int

Maximum attainable value.

Complexity

O(n * capacity) time, O(capacity) memory.

Source code in cp_snippets/dp.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def knapsack_01(weights: Sequence[int], values: Sequence[int], capacity: int) -> int:
    """0/1 knapsack maximum value.

    Args:
        weights: item weights (each must be non-negative)
        values: item values, parallel to ``weights``
        capacity: knapsack capacity (must be non-negative)

    Returns:
        Maximum attainable value using each item at most once.

    Raises:
        ValueError: on mismatched lengths, negative capacity, or a negative
            weight.

    Complexity:
        O(n * capacity) time, O(capacity) memory.
    """
    if len(weights) != len(values):
        raise ValueError("weights and values must have the same length")
    if capacity < 0:
        raise ValueError("capacity must be non-negative")

    # best[c] is the best value achievable with total weight <= c.
    best = [0] * (capacity + 1)
    for weight, value in zip(weights, values):
        if weight < 0:
            raise ValueError("weights must be non-negative")
        # Descending capacities so best[c - weight] still excludes this item.
        c = capacity
        while c >= weight:
            take = best[c - weight] + value
            if take > best[c]:
                best[c] = take
            c -= 1
    return best[capacity]

lis_length(arr)

Length of the Longest Increasing Subsequence (strict) in O(n log n).

Source code in cp_snippets/dp.py
 7
 8
 9
10
11
12
13
14
15
16
def lis_length(arr: Sequence[int]) -> int:
    """Length of the Longest Increasing Subsequence (strict) in O(n log n).

    ``piles[k]`` holds the smallest value that can end a strictly increasing
    subsequence of length k + 1, and the final number of piles is the answer.
    bisect_left (rather than bisect_right) keeps equal values from extending
    a run, which makes the subsequence strict.
    """
    piles: List[int] = []
    for value in arr:
        pos = bisect_left(piles, value)
        if pos < len(piles):
            piles[pos] = value
        else:
            piles.append(value)
    return len(piles)

Graphs

UnionFind

Disjoint Set Union (Union-Find) with path compression + union by size.

Source code in cp_snippets/graphs.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class UnionFind:
	"""Disjoint Set Union (Union-Find) with path halving + union by size.

	Elements are the integers 0..n-1. ``parent`` stores each element's parent
	(a root is its own parent), and ``size[r]`` is the element count of the
	set rooted at r. It is only kept up to date for roots.
	"""

	def __init__(self, n: int):
		if n < 0:
			raise ValueError("n must be non-negative")
		self.parent = [i for i in range(n)]
		self.size = [1 for _ in range(n)]

	def find(self, x: int) -> int:
		"""Returns the root of x's set, halving the path along the way."""
		links = self.parent
		while links[x] != x:
			# Re-point x at its grandparent, then continue from there.
			grandparent = links[links[x]]
			links[x] = grandparent
			x = grandparent
		return x

	def union(self, a: int, b: int) -> bool:
		"""Merges the sets of a and b; returns False if they were already joined."""
		root_a = self.find(a)
		root_b = self.find(b)
		if root_a == root_b:
			return False
		# Hang the smaller tree under the larger one (ties keep a's root on top).
		if self.size[root_a] < self.size[root_b]:
			big, small = root_b, root_a
		else:
			big, small = root_a, root_b
		self.parent[small] = big
		self.size[big] += self.size[small]
		return True

	def same(self, a: int, b: int) -> bool:
		"""Returns True iff a and b belong to the same set."""
		return self.find(a) == self.find(b)

bfs_dist(n, adj, src)

Shortest distances from src in an unweighted graph.

Returns a list dist where dist[v] is the number of edges in the shortest path from src to v, or -1 if unreachable.

Source code in cp_snippets/graphs.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def bfs_dist(n: int, adj: Sequence[Sequence[int]], src: int) -> List[int]:
	"""Shortest distances from `src` in an unweighted graph.

	Returns a list `dist` of length n where dist[v] is the number of edges on
	a shortest path from src to v, or -1 if v cannot be reached.
	"""
	dist = [-1] * n
	dist[src] = 0
	frontier: Deque[int] = deque()
	frontier.append(src)
	while frontier:
		node = frontier.popleft()
		next_level = dist[node] + 1
		for nbr in adj[node]:
			# -1 doubles as the "not yet visited" marker.
			if dist[nbr] == -1:
				dist[nbr] = next_level
				frontier.append(nbr)
	return dist

bfs_path(n, adj, src, dst)

Returns one shortest path (as a list of vertices) from src to dst in an unweighted graph.

Source code in cp_snippets/graphs.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
def bfs_path(n: int, adj: Sequence[Sequence[int]], src: int, dst: int) -> Optional[List[int]]:
	"""Returns one shortest path (as a list of vertices) from src to dst in an unweighted graph.

	The path starts with src and ends with dst. It is None if dst is
	unreachable, and [src] when src == dst. The BFS stops as soon as dst is
	dequeued.
	"""
	came_from = [-1] * n
	queue: Deque[int] = deque([src])
	came_from[src] = src  # src is its own predecessor; marks it as visited
	while queue:
		node = queue.popleft()
		if node == dst:
			break
		for nbr in adj[node]:
			if came_from[nbr] == -1:
				came_from[nbr] = node
				queue.append(nbr)

	if came_from[dst] == -1:
		return None

	# Walk predecessors back from dst, then flip into src -> dst order.
	route = [dst]
	while route[-1] != src:
		route.append(came_from[route[-1]])
	return route[::-1]

connected_components(n, adj)

Connected components in an undirected graph.

Source code in cp_snippets/graphs.py
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
def connected_components(n: int, adj: Sequence[Sequence[int]]) -> List[List[int]]:
	"""Connected components in an undirected graph.

	Returns one list of vertices per component. Components appear in order of
	their smallest vertex, and each list is in the order an explicit-stack
	DFS visits its vertices.
	"""
	visited = [False] * n
	components: List[List[int]] = []
	for start in range(n):
		if visited[start]:
			continue
		visited[start] = True
		members: List[int] = []
		pending = [start]
		while pending:
			node = pending.pop()
			members.append(node)
			for nbr in adj[node]:
				# Mark on push so no vertex is stacked twice.
				if not visited[nbr]:
					visited[nbr] = True
					pending.append(nbr)
		components.append(members)
	return components

dfs_iterative(adj, start)

Iterative DFS that returns the visitation order.

Source code in cp_snippets/graphs.py
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
def dfs_iterative(adj: Sequence[Sequence[int]], start: int) -> List[int]:
	"""Iterative DFS that returns the visitation order.

	Only vertices reachable from ``start`` are visited; the graph size is
	taken from len(adj).
	"""
	visited = [False] * len(adj)
	order: List[int] = []
	stack = [start]
	while stack:
		node = stack.pop()
		if visited[node]:
			continue  # a vertex can be stacked more than once; handle it once
		visited[node] = True
		order.append(node)
		# Push neighbours last-to-first so the first-listed one is popped next,
		# mimicking the order a recursive DFS would produce.
		stack.extend(nbr for nbr in reversed(adj[node]) if not visited[nbr])
	return order

dijkstra(n, adj, src)

Dijkstra shortest paths for non-negative edge weights.

Parameters:

Name Type Description Default
n int

Number of vertices.

required
adj Sequence[Sequence[Tuple[int, int]]]

Adjacency list where adj[u] contains (v, w).

required
src int

Source vertex.

required

Returns:

Type Description
Tuple[List[int], List[int]]

(dist, parent), where dist[v] is the shortest distance from src to v (or a large number if unreachable) and parent[v] is the previous vertex on a shortest path (or -1 if unreachable; src's parent is src).

Source code in cp_snippets/graphs.py
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
def dijkstra(
	n: int, adj: Sequence[Sequence[Tuple[int, int]]], src: int
) -> Tuple[List[int], List[int]]:
	"""Dijkstra shortest paths for non-negative edge weights.

	Args:
		n: Number of vertices.
		adj: Adjacency list where adj[u] contains (v, w) pairs.
		src: Source vertex.

	Returns:
		(dist, parent), where dist[v] is the shortest distance from src to v
		(10**30 when unreachable), and parent[v] is the previous vertex on a
		shortest path (-1 when unreachable; parent[src] == src).
	"""
	INF = 10**30
	dist = [INF] * n
	parent = [-1] * n
	dist[src] = 0
	parent[src] = src
	heap: List[Tuple[int, int]] = [(0, src)]
	push, pop = heapq.heappush, heapq.heappop
	while heap:
		d, u = pop(heap)
		if d != dist[u]:
			continue  # stale entry: u was already settled with a shorter distance
		for v, w in adj[u]:
			candidate = d + w
			if candidate < dist[v]:
				dist[v] = candidate
				parent[v] = u
				push(heap, (candidate, v))
	return dist, parent

reconstruct_path(parent, src, dst)

Reconstruct a path from src to dst given a parent array.

The convention is that parent[src] == src, and parent[v] == -1 means unreachable.

Source code in cp_snippets/graphs.py
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
def reconstruct_path(parent: Sequence[int], src: int, dst: int) -> Optional[List[int]]:
	"""Reconstruct a path from `src` to `dst` given a `parent` array.

	The convention is that parent[src] == src, and parent[v] == -1 means unreachable.

	Returns:
		The vertices from src to dst inclusive, or None if dst is unreachable
		or following ``parent`` from dst never arrives at src. The latter
		happens when the array was built from a different source, or
		contains a cycle.

	Raises:
		IndexError: if dst is not a valid index into ``parent``.
	"""
	if dst < 0 or dst >= len(parent):
		raise IndexError("dst out of bounds")
	if parent[dst] == -1:
		return None
	# A genuine path visits each vertex at most once. Once the walk holds
	# len(parent) vertices without reaching src, it must be looping (e.g. on
	# another source's self-parent), so give up instead of growing forever.
	limit = len(parent)
	path = [dst]
	while path[-1] != src:
		if len(path) >= limit:
			return None
		p = parent[path[-1]]
		if p == -1:
			return None
		path.append(p)
	path.reverse()
	return path

topological_sort_kahn(n, adj)

Topological sort of a directed graph.

Returns:

Type Description
Optional[List[int]]

A topological order list of length n, or None if a cycle exists.

Source code in cp_snippets/graphs.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
def topological_sort_kahn(n: int, adj: Sequence[Sequence[int]]) -> Optional[List[int]]:
	"""Topological sort of a directed graph (Kahn's algorithm).

	Vertices are processed FIFO, starting from the zero in-degree vertices in
	increasing index order.

	Returns:
		A topological order list of length n, or None if a cycle exists.
	"""
	indegree = [0] * n
	for u in range(n):
		for v in adj[u]:
			indegree[v] += 1

	ready: Deque[int] = deque(u for u in range(n) if indegree[u] == 0)
	order: List[int] = []
	while ready:
		u = ready.popleft()
		order.append(u)
		for v in adj[u]:
			indegree[v] -= 1
			if not indegree[v]:
				ready.append(v)

	# Vertices on a cycle never reach in-degree 0, so they are never emitted.
	return order if len(order) == n else None

Math

euler_totient(n, spf)

Computes Euler's totient function phi(n) using a precomputed SPF (smallest prime factor) array. phi(n) is the number of integers in 1..n that are coprime to n.

Source code in cp_snippets/math_utils.py
116
117
118
119
120
121
122
123
124
125
126
127
def euler_totient(n, spf):
    """
    Computes Euler's Totient function phi(n) using SPF.
    phi(n) = count of integers <= n coprime to n.

    spf[x] must give the smallest prime factor of x for every 2 <= x <= n
    (e.g. the output of spf_sieve). For each distinct prime p dividing n,
    the result is multiplied by (1 - 1/p) using exact integer arithmetic.
    """
    result = n
    remaining = n
    while remaining > 1:
        prime = spf[remaining]
        result -= result // prime
        # Strip every copy of this prime so it is applied only once.
        while remaining % prime == 0:
            remaining //= prime
    return result

gcd(a, b)

Computes the Greatest Common Divisor of a and b.

Source code in cp_snippets/math_utils.py
1
2
3
4
5
def gcd(a, b):
    """Computes the Greatest Common Divisor of a and b.

    The result is always non-negative, matching math.gcd, and
    gcd(0, 0) == 0.
    """
    while b:
        a, b = b, a % b
    # Python's % takes the sign of the divisor, so with negative inputs the
    # loop can end on a negative value (e.g. gcd(4, -6) would give -2).
    return abs(a)

init_nCr(N, mod)

Precomputes factorials and inverse factorials up to N for nCr calculations. Returns: (fact, invfact) lists.

Source code in cp_snippets/math_utils.py
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def init_nCr(N, mod):
    """
    Precomputes factorials and inverse factorials up to N for nCr calculations.
    Returns: (fact, invfact) lists, each of length N + 1 and reduced modulo mod.

    Only fact[N] is inverted directly (via mod_inv). The remaining entries
    come from invfact[i - 1] = invfact[i] * i, walking downwards. Because
    mod_inv assumes a prime modulus, mod should be a prime larger than N so
    that fact[N] is invertible.
    """
    fact = [1] * (N + 1)
    for i in range(1, N + 1):
        fact[i] = fact[i - 1] * i % mod

    invfact = [1] * (N + 1)
    invfact[N] = mod_inv(fact[N], mod)
    for i in range(N, 0, -1):
        invfact[i - 1] = invfact[i] * i % mod

    return fact, invfact

lcm(a, b)

Computes the Least Common Multiple of a and b.

Source code in cp_snippets/math_utils.py
 8
 9
10
11
12
def lcm(a, b):
    """Computes the Least Common Multiple of a and b.

    The result is always non-negative, and lcm(x, 0) == lcm(0, x) == 0.
    """
    if a == 0 or b == 0:
        return 0
    # Divide before multiplying to keep the intermediate small. The outer
    # abs() keeps the result non-negative even if gcd returns a negative
    # value for negative inputs.
    return abs(a // gcd(a, b) * b)

mod_add(a, b, mod)

Computes (a + b) % mod.

Source code in cp_snippets/math_utils.py
15
16
17
def mod_add(a, b, mod):
    """Computes (a + b) % mod; the result takes the sign of mod (Python %)."""
    total = a + b
    return total % mod

mod_inv(a, mod)

Computes modular inverse of a modulo mod using Fermat's Little Theorem. Assumes mod is prime.

Source code in cp_snippets/math_utils.py
37
38
39
40
41
42
def mod_inv(a, mod):
    """
    Computes modular inverse of a modulo mod using Fermat's Little Theorem.
    Assumes mod is prime.

    Returns x with a * x % mod == 1, computed as a^(mod - 2) % mod.

    Raises:
        ValueError: if a is divisible by mod. Such an a has no inverse, and
            Fermat's formula would otherwise silently return a wrong value.
    """
    if a % mod == 0:
        raise ValueError("a has no inverse modulo mod")
    return mod_pow(a, mod - 2, mod)

mod_mul(a, b, mod)

Computes (a * b) % mod.

Source code in cp_snippets/math_utils.py
20
21
22
def mod_mul(a, b, mod):
    """Computes (a * b) % mod; the result takes the sign of mod (Python %)."""
    product = a * b
    return product % mod

mod_pow(a, b, mod)

Computes (a^b) % mod using binary exponentiation.

Source code in cp_snippets/math_utils.py
25
26
27
28
29
30
31
32
33
34
def mod_pow(a, b, mod):
    """Computes (a^b) % mod using binary exponentiation.

    Agrees with the built-in pow(a, b, mod) for b >= 0 and mod >= 1,
    including returning 0 when mod == 1.

    Raises:
        ValueError: if b is negative. The loop relies on b reaching 0; for
            a negative b, ``b >>= 1`` stalls at -1 and the loop never ends.
    """
    if b < 0:
        raise ValueError("exponent must be non-negative")
    res = 1 % mod  # a^0 reduced modulo mod (0 when mod == 1)
    a %= mod
    while b:
        if b & 1:
            res = res * a % mod
        a = a * a % mod
        b >>= 1
    return res

nCr(n, r, fact, invfact, mod)

Computes nCr % mod using precomputed factorials. Returns 0 if r < 0 or r > n.

Source code in cp_snippets/math_utils.py
63
64
65
66
67
68
69
70
def nCr(n, r, fact, invfact, mod):
    """
    Computes nCr % mod using precomputed factorials.
    Returns 0 if r < 0 or r > n.

    fact and invfact are the tables returned by init_nCr for the same mod;
    n must be within their range.
    """
    if not 0 <= r <= n:
        return 0
    numerator = fact[n]
    # n! / (r! (n-r)!) as a product with the inverse factorials, reducing after each step.
    return numerator * invfact[r] % mod * invfact[n - r] % mod

prime_factorize(x, spf)

Returns prime factorization of x using precomputed SPF array. Returns: Dictionary {prime_factor: exponent}.

Source code in cp_snippets/math_utils.py
103
104
105
106
107
108
109
110
111
112
113
def prime_factorize(x, spf):
    """
    Returns prime factorization of x using precomputed SPF array.
    Returns: Dictionary {prime_factor: exponent}.

    Primes appear in increasing order of extraction (smallest first). x <= 1
    yields an empty dict.
    """
    factors = {}
    while x > 1:
        p = spf[x]
        if p in factors:
            factors[p] += 1
        else:
            factors[p] = 1
        x //= p
    return factors

sieve(n)

Sieve of Eratosthenes to find primes up to n. Returns: is_prime boolean list where is_prime[i] is True if i is prime.

Source code in cp_snippets/math_utils.py
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def sieve(n):
    """
    Sieve of Eratosthenes to find primes up to n.
    Returns: is_prime boolean list where is_prime[i] is True if i is prime.

    The list has n + 1 entries (indices 0..n). Small inputs are valid:
    sieve(0) == [False], sieve(1) == [False, False], and a negative n
    yields [].
    """
    if n < 0:
        # No indices 0..n exist; also avoids int() on the complex n ** 0.5.
        return []
    is_prime = [True] * (n + 1)
    is_prime[0] = False
    if n >= 1:  # index 1 only exists once n >= 1
        is_prime[1] = False

    for i in range(2, int(n ** 0.5) + 1):
        if is_prime[i]:
            # Cross off i*i, i*i + i, ... <= n in one slice assignment;
            # smaller multiples were already removed by smaller primes.
            is_prime[i * i :: i] = [False] * len(range(i * i, n + 1, i))

    return is_prime

spf_sieve(n)

Computes Smallest Prime Factor (SPF) for each number up to n. Useful for fast prime factorization.

Source code in cp_snippets/math_utils.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def spf_sieve(n):
    """
    Computes Smallest Prime Factor (SPF) for each number up to n.
    Useful for fast prime factorization.

    Returns a list of length n + 1 where entry x is the smallest prime
    dividing x for x >= 2; entries 0 and 1 are left as 0 and 1.
    """
    smallest = list(range(n + 1))
    limit = int(n ** 0.5)
    for p in range(2, limit + 1):
        if smallest[p] != p:
            continue  # p is composite; its multiples are covered by a smaller prime
        for multiple in range(p * p, n + 1, p):
            # Only the first (smallest) prime to reach a cell gets to claim it.
            if smallest[multiple] == multiple:
                smallest[multiple] = p
    return smallest

IO

char_matrix(n)

Reads n lines of strings and converts them into a list of list of characters.

Source code in cp_snippets/io.py
38
39
40
def char_matrix(n):
    """Reads n lines of strings and converts them into a list of list of characters.

    Each line has leading/trailing whitespace stripped before being split
    into characters.
    """
    rows = []
    for _ in range(n):
        rows.append(list(input().strip()))
    return rows

fast_reader()

Returns an iterator that yields tokens from stdin one by one.

Source code in cp_snippets/io.py
58
59
60
def fast_reader():
    """Returns an iterator that yields tokens from stdin one by one.

    All of stdin is read eagerly when this is called; tokens are the
    whitespace-separated pieces of that text.
    """
    tokens = sys.stdin.read().split()
    return iter(tokens)

int_input()

Reads a single integer from standard input.

Source code in cp_snippets/io.py
5
6
7
def int_input():
    """Reads a single integer from standard input (one whole line)."""
    line = input()
    return int(line)

list_input(dtype=int)

Reads a line of space-separated values and returns a list. Default element type is int.

Source code in cp_snippets/io.py
15
16
17
def list_input(dtype=int):
    """Reads a line of space-separated values and returns a list. Default element type is int.

    ``dtype`` is called on each whitespace-separated token.
    """
    return [dtype(token) for token in input().split()]

matrix_input(n, m=None, dtype=int)

Reads a matrix of n rows, one row per input line. If m is given, each row is truncated to its first m values; if m is None, each line may have any number of elements.

Source code in cp_snippets/io.py
28
29
30
31
32
33
34
35
def matrix_input(n, m=None, dtype=int):
    """
    Reads a matrix of size n x m.
    If m is None, reads n lines where each line can have arbitrary number of elements.

    Every token on a line is converted with ``dtype``; when m is given, the
    converted row is then truncated to its first m values.
    """
    matrix = []
    for _ in range(n):
        row = list(map(dtype, input().split()))
        matrix.append(row if m is None else row[:m])
    return matrix

multi_input(dtype=int)

Reads a line of space-separated values and returns a map object. Usage: a, b, c = multi_input()

Source code in cp_snippets/io.py
20
21
22
23
24
25
def multi_input(dtype=int):
    """
    Reads a line of space-separated values and returns a map object.
    Usage: a, b, c = multi_input()

    The map is lazy: ``dtype`` runs on each token only as it is consumed.
    """
    tokens = input().split()
    return map(dtype, tokens)

print_list(arr, sep=' ')

Prints elements of a list separated by 'sep'.

Source code in cp_snippets/io.py
48
49
50
def print_list(arr, sep=' '):
    """Prints elements of a list separated by 'sep', followed by a newline.

    Each element is converted with str(); output goes through a single
    sys.stdout.write call.
    """
    line = sep.join(str(item) for item in arr)
    sys.stdout.write(line + '\n')

print_yes_no(cond)

Prints 'YES' if cond is True, else 'NO'.

Source code in cp_snippets/io.py
53
54
55
def print_yes_no(cond):
    """Prints 'YES' if cond is truthy, else 'NO' (each followed by a newline)."""
    if cond:
        answer = "YES\n"
    else:
        answer = "NO\n"
    sys.stdout.write(answer)

read_all()

Reads all input from stdin until EOF and returns a list of tokens.

Source code in cp_snippets/io.py
43
44
45
def read_all():
    """Reads all input from stdin until EOF and returns a list of tokens.

    Tokens are the whitespace-separated pieces of the input, as strings.
    """
    text = sys.stdin.read()
    return text.split()

str_input()

Reads a single line of string from standard input, stripping leading/trailing whitespace.

Source code in cp_snippets/io.py
10
11
12
def str_input():
    """Reads a single line of string from standard input, stripping leading/trailing whitespace."""
    raw = input()
    return raw.strip()

Misc

ceil_div(a, b)

Ceiling division for integers.

Works for negative values as well.

Source code in cp_snippets/misc.py
 9
10
11
12
13
14
15
16
def ceil_div(a: int, b: int) -> int:
	"""Ceiling division for integers.

	Works for negative values as well: returns the smallest integer >= a / b
	using exact integer arithmetic (no float rounding).

	Raises:
		ZeroDivisionError: if b is 0.
	"""
	if not b:
		raise ZeroDivisionError("division by zero")
	quotient, remainder = divmod(a, b)
	# divmod floors; any non-zero remainder means the true quotient lies above it.
	return quotient + (1 if remainder else 0)

chunks(arr, size)

Yields consecutive chunks of length size from arr.

Source code in cp_snippets/misc.py
44
45
46
47
48
49
def chunks(arr: Sequence[T], size: int) -> Iterable[Sequence[T]]:
	"""Yields consecutive chunks of length `size` from `arr`.

	Every chunk has exactly `size` elements except possibly the last, which
	holds the remainder. Chunks are slices of `arr`, so a list yields lists
	and a string yields strings. The chunks are produced lazily.

	Raises:
		ValueError: if size <= 0. This is raised when chunks() is called,
			rather than deferred to the first iteration as it would be inside
			a generator function.
	"""
	if size <= 0:
		raise ValueError("size must be positive")
	return (arr[i : i + size] for i in range(0, len(arr), size))

coordinate_compress(values)

Coordinate compression.

Parameters:

Name Type Description Default
values Sequence[T]

Any sortable values.

required

Returns:

Type Description
Tuple[List[int], List[T], Dict[T, int]]

(compressed, uniq, index), where compressed is a list of ints with the same length as values, uniq is the sorted list of unique values, and index maps each value to its compressed index.

Source code in cp_snippets/misc.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def coordinate_compress(values: Sequence[T]) -> Tuple[List[int], List[T], Dict[T, int]]:
	"""Coordinate compression.

	Replaces each value with its rank among the distinct values, so the
	smallest distinct value becomes 0. Values must be hashable and sortable.

	Args:
		values: Any sortable values.

	Returns:
		(compressed, uniq, index), where compressed holds the ranks (same
		length and order as values), uniq is the sorted list of distinct
		values, and index maps each distinct value to its rank.
	"""
	uniq = sorted(set(values))
	index = dict(zip(uniq, range(len(uniq))))
	compressed = list(map(index.__getitem__, values))
	return compressed, uniq, index

floor_div(a, b)

Floor division for integers (same as a // b, here for symmetry).

Source code in cp_snippets/misc.py
19
20
21
22
23
def floor_div(a: int, b: int) -> int:
	"""Floor division for integers (same as `a // b`, here for symmetry).

	Rounds towards negative infinity, e.g. floor_div(-7, 2) == -4.

	Raises:
		ZeroDivisionError: if b is 0.
	"""
	if not b:
		raise ZeroDivisionError("division by zero")
	quotient = a // b
	return quotient

Trees

BinaryLiftingLCA

Lowest Common Ancestor for a rooted tree via binary lifting.

Complexity

Preprocess: O(n log n); query LCA: O(log n).

Source code in cp_snippets/trees.py
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
class BinaryLiftingLCA:
	"""Lowest Common Ancestor for a rooted tree via binary lifting.

	``up[k][v]`` is the 2**k-th ancestor of v. The root is recorded as its
	own parent, so jumps past the root stay at the root. ``depth[v]`` is the
	number of edges from the root to v.

	Complexity:
		Preprocess: O(n log n)
		Query LCA:  O(log n)
	"""

	def __init__(self, n: int, adj: Sequence[Sequence[int]], root: int = 0):
		# NOTE(review): assumes every vertex is reachable from root (a tree).
		# An unreached vertex keeps parent -1, and that -1 then wraps around
		# to the last vertex while the higher jump rows are built.
		self.n = n
		self.root = root
		self.LOG = n.bit_length()
		self.up = [[-1] * n for _ in range(self.LOG)]
		self.depth = [-1] * n

		frontier: Deque[int] = deque([root])
		self.depth[root] = 0
		self.up[0][root] = root
		parents = self.up[0]
		depth = self.depth
		# BFS from the root fills in depths and the immediate-parent row.
		while frontier:
			u = frontier.popleft()
			child_depth = depth[u] + 1
			for v in adj[u]:
				if depth[v] == -1:
					depth[v] = child_depth
					parents[v] = u
					frontier.append(v)

		# Row k applies row k-1 twice: a 2**k jump is two 2**(k-1) jumps.
		for k in range(1, self.LOG):
			half = self.up[k - 1]
			self.up[k] = [half[half[v]] for v in range(n)]

	def kth_ancestor(self, v: int, k: int) -> int:
		"""Returns the k-th ancestor of v (0-th ancestor is v).

		Takes one 2**level jump for each set bit of k. For 0 <= k < 2**LOG, a
		k larger than depth[v] ends at the root. Any other k eventually
		indexes past ``up`` and raises IndexError.
		"""
		level = 0
		while k != 0:
			if k & 1:
				v = self.up[level][v]
			level += 1
			k >>= 1
		return v

	def lca(self, a: int, b: int) -> int:
		"""Returns LCA(a, b)."""
		depth = self.depth
		if depth[b] > depth[a]:
			a, b = b, a
		# Bring the deeper vertex a up to b's depth.
		a = self.kth_ancestor(a, depth[a] - depth[b])
		if a == b:
			return a
		# Try the largest jumps first and move both vertices only when they
		# would land on different vertices. They end as children of the LCA.
		for k in reversed(range(self.LOG)):
			jump_a = self.up[k][a]
			jump_b = self.up[k][b]
			if jump_a != jump_b:
				a, b = jump_a, jump_b
		return self.up[0][a]

	def dist(self, a: int, b: int) -> int:
		"""Number of edges on the path between a and b."""
		meet = self.lca(a, b)
		return self.depth[a] - self.depth[meet] + self.depth[b] - self.depth[meet]

dist(a, b)

Number of edges on the path between a and b.

Source code in cp_snippets/trees.py
171
172
173
174
def dist(self, a: int, b: int) -> int:
	"""Number of edges on the path between a and b.

	The path climbs from a up to LCA(a, b) and back down to b, so its length
	is the sum of the two depth differences.
	"""
	meet = self.lca(a, b)
	depth = self.depth
	return (depth[a] - depth[meet]) + (depth[b] - depth[meet])

kth_ancestor(v, k)

Returns the k-th ancestor of v (0-th ancestor is v).

Source code in cp_snippets/trees.py
143
144
145
146
147
148
149
150
151
def kth_ancestor(self, v: int, k: int) -> int:
	"""Returns the k-th ancestor of v (0-th ancestor is v).

	If k exceeds depth[v], the root is returned. For a vertex not reachable
	from the root (depth -1), returns v when k == 0 and -1 otherwise.

	Raises:
		ValueError: if k is negative.
	"""
	if k < 0:
		raise ValueError("k must be non-negative")
	d = self.depth[v]
	if d < 0:
		return v if k == 0 else -1
	# The root is its own ancestor, so clamping does not change the result.
	# It also keeps every bit index below LOG.
	k = min(k, d)
	i = 0
	while k:
		if k & 1:
			v = self.up[i][v]
		k >>= 1
		i += 1
	return v

lca(a, b)

Returns LCA(a, b).

Source code in cp_snippets/trees.py
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
def lca(self, a: int, b: int) -> int:
	"""Returns LCA(a, b).

	Raises:
		ValueError: if a or b is not reachable from the root (depth -1).
	"""
	if self.depth[a] < 0 or self.depth[b] < 0:
		raise ValueError("vertex not reachable from root")
	if self.depth[a] < self.depth[b]:
		a, b = b, a

	# lift a to b's depth
	diff = self.depth[a] - self.depth[b]
	a = self.kth_ancestor(a, diff)
	if a == b:
		return a

	# Jump both as far as possible while they stay distinct; their parent is the LCA.
	for k in range(self.LOG - 1, -1, -1):
		if self.up[k][a] != self.up[k][b]:
			a = self.up[k][a]
			b = self.up[k][b]

	return self.up[0][a]

FenwickTree

Fenwick Tree (Binary Indexed Tree) for point updates and prefix sums.

Indexing

Public methods accept 0-based indices; internally we use 1-based.

Source code in cp_snippets/trees.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class FenwickTree:
	"""Fenwick Tree (Binary Indexed Tree) for point updates and prefix sums.

	Indexing:
		Public methods accept 0-based indices; internally we use 1-based.
	"""

	def __init__(self, n_or_arr: int | Sequence[int]):
		if isinstance(n_or_arr, int):
			n = n_or_arr
			if n < 0:
				raise ValueError("n must be non-negative")
			self.n = n
			self.bit = [0] * (n + 1)
		else:
			arr = list(n_or_arr)
			self.n = len(arr)
			self.bit = [0] * (self.n + 1)
			for i, v in enumerate(arr):
				self.add(i, v)

	def add(self, idx: int, delta: int) -> None:
		"""Adds `delta` to a[idx]."""
		i = idx + 1
		while i <= self.n:
			self.bit[i] += delta
			i += i & -i

	def sum_prefix(self, r: int) -> int:
		"""Returns sum(a[0..r]) inclusive. If r < 0 returns 0."""
		if r < 0:
			return 0
		i = r + 1
		res = 0
		while i > 0:
			res += self.bit[i]
			i -= i & -i
		return res

	def sum_range(self, l: int, r: int) -> int:
		"""Returns sum(a[l..r]) inclusive."""
		if r < l:
			return 0
		return self.sum_prefix(r) - self.sum_prefix(l - 1)

add(idx, delta)

Adds delta to a[idx].

Source code in cp_snippets/trees.py
29
30
31
32
33
34
def add(self, idx: int, delta: int) -> None:
	"""Adds `delta` to a[idx].

	Raises:
		IndexError: if idx is outside [0, n).
	"""
	if not 0 <= idx < self.n:
		# A negative idx starts the walk at i <= 0, where i & -i never moves
		# i past 0 (an infinite loop). An idx >= n would be silently dropped.
		raise IndexError("FenwickTree index out of range")
	i = idx + 1
	while i <= self.n:
		self.bit[i] += delta
		i += i & -i

sum_prefix(r)

Returns sum(a[0..r]) inclusive. If r < 0 returns 0.

Source code in cp_snippets/trees.py
36
37
38
39
40
41
42
43
44
45
def sum_prefix(self, r: int) -> int:
	"""Return the inclusive prefix sum a[0] + ... + a[r]; 0 when r is negative."""
	total = 0
	if r >= 0:
		node = r + 1
		while node:
			total += self.bit[node]
			node &= node - 1  # clear the lowest set bit to reach the next covering node
	return total

sum_range(l, r)

Returns sum(a[l..r]) inclusive.

Source code in cp_snippets/trees.py
47
48
49
50
51
def sum_range(self, l: int, r: int) -> int:
	"""Return a[l] + ... + a[r] (inclusive); an empty range (r < l) sums to 0."""
	return 0 if r < l else self.sum_prefix(r) - self.sum_prefix(l - 1)

SegmentTree

Iterative segment tree for range queries and point updates.

Default operation is sum; you can supply any associative op with identity e.

Query semantics

query(l, r) returns op over the half-open interval [l, r).

Source code in cp_snippets/trees.py
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
class SegmentTree:
	"""Iterative segment tree for range queries and point updates.

	Default operation is sum; you can supply any associative `op` with identity `e`.

	Query semantics:
		query(l, r) returns op over the half-open interval [l, r).
	"""

	def __init__(
		self,
		arr: Sequence[int],
		op: Callable[[int, int], int] = lambda a, b: a + b,
		e: int = 0,
	):
		"""Builds the tree over `arr` in O(n).

		Args:
			arr: initial values a[0..n).
			op: associative binary operation. Operand order is preserved in
				queries, so non-commutative operations work.
			e: identity element of `op`, used for padding and empty ranges.
		"""
		self.n = len(arr)
		self.op = op
		self.e = e
		# Round the leaf count up to a power of two; leaf i lives at data[size + i].
		self.size = 1
		while self.size < self.n:
			self.size <<= 1
		self.data = [e] * (2 * self.size)
		# build
		for i in range(self.n):
			self.data[self.size + i] = arr[i]
		# Node i combines children 2i and 2i + 1; data[0] is unused.
		for i in range(self.size - 1, 0, -1):
			self.data[i] = op(self.data[2 * i], self.data[2 * i + 1])

	def update(self, idx: int, value: int) -> None:
		"""Sets a[idx] = value.

		Raises:
			IndexError: if idx is outside [0, n).
		"""
		if not 0 <= idx < self.n:
			# A negative idx would overwrite an internal node, and idx >= n would
			# write into the identity padding.
			raise IndexError("SegmentTree index out of range")
		i = self.size + idx
		self.data[i] = value
		i //= 2
		while i:
			self.data[i] = self.op(self.data[2 * i], self.data[2 * i + 1])
			i //= 2

	def query(self, l: int, r: int) -> int:
		"""Returns op(a[l..r)).

		The range is clipped to [0, n); an empty range yields the identity `e`.
		"""
		# Clipping stops a negative l from reading internal nodes and a large r
		# from running past the end of data.
		l = max(l, 0)
		r = min(r, self.n)
		res_left = self.e
		res_right = self.e
		l += self.size
		r += self.size
		while l < r:
			if l & 1:
				res_left = self.op(res_left, self.data[l])
				l += 1
			if r & 1:
				r -= 1
				# Accumulate the right side separately to keep operand order.
				res_right = self.op(self.data[r], res_right)
			l //= 2
			r //= 2
		return self.op(res_left, res_right)

query(l, r)

Returns op(a[l..r))

Source code in cp_snippets/trees.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
def query(self, l: int, r: int) -> int:
	"""Returns op(a[l..r)).

	The range is clipped to [0, n); an empty range yields the identity `e`.
	"""
	# Clipping stops a negative l from reading internal nodes and a large r
	# from running past the end of data.
	l = max(l, 0)
	r = min(r, self.n)
	res_left = self.e
	res_right = self.e
	l += self.size
	r += self.size
	while l < r:
		if l & 1:
			res_left = self.op(res_left, self.data[l])
			l += 1
		if r & 1:
			r -= 1
			# Accumulate the right side separately to keep operand order.
			res_right = self.op(self.data[r], res_right)
		l //= 2
		r //= 2
	return self.op(res_left, res_right)

update(idx, value)

Sets a[idx] = value.

Source code in cp_snippets/trees.py
82
83
84
85
86
87
88
89
def update(self, idx: int, value: int) -> None:
	"""Sets a[idx] = value.

	Raises:
		IndexError: if idx is outside [0, n).
	"""
	if not 0 <= idx < self.n:
		# A negative idx would overwrite an internal node, and idx >= n would
		# write into the identity padding.
		raise IndexError("SegmentTree index out of range")
	i = self.size + idx
	self.data[i] = value
	i //= 2
	# Recompute every ancestor from its two children.
	while i:
		self.data[i] = self.op(self.data[2 * i], self.data[2 * i + 1])
		i //= 2

Strings

RollingHash

Double rolling hash implementation for string matching and hashing. 1-based indexing for hash queries is often easier, but we stick to 0-based for consistency with Python.

Source code in cp_snippets/strings.py
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
class RollingHash:
    """
    Double rolling hash implementation for string matching and hashing.
    1-based indexing for hash queries is often easier, but we stick to 0-based for consistency with Python.
    """
    def __init__(self, s: str, base1: int = 313, mod1: int = 10**9 + 7, base2: int = 317, mod2: int = 10**9 + 9):
        """
        Precomputes prefix hashes and base powers for both (base, mod) pairs.
        hash[i] is the polynomial hash of s[0:i]; pow[i] is base**i modulo mod.
        """
        self.mod1 = mod1
        self.mod2 = mod2
        self.base1 = base1
        self.base2 = base2
        n = len(s)

        self.hash1 = [0] * (n + 1)
        self.hash2 = [0] * (n + 1)
        self.pow1 = [1] * (n + 1)
        self.pow2 = [1] * (n + 1)

        for i in range(n):
            self.hash1[i + 1] = (self.hash1[i] * base1 + ord(s[i])) % mod1
            self.hash2[i + 1] = (self.hash2[i] * base2 + ord(s[i])) % mod2
            self.pow1[i + 1] = (self.pow1[i] * base1) % mod1
            self.pow2[i + 1] = (self.pow2[i] * base2) % mod2

    def get_hash(self, l: int, r: int) -> Tuple[int, int]:
        """
        Returns the double hash of substring s[l...r] (inclusive).
        An empty range (l == r + 1) hashes to (0, 0).

        Raises:
            IndexError: unless 0 <= l <= r + 1 <= len(s).
        """
        n = len(self.hash1) - 1
        if not 0 <= l <= r + 1 <= n:
            # Negative indices or l > r + 1 would silently wrap around the
            # prefix/power tables and produce a meaningless hash.
            raise IndexError("substring range out of bounds")
        # hash[r+1] - hash[l] * base^(r-l+1)
        h1 = (self.hash1[r + 1] - self.hash1[l] * self.pow1[r - l + 1]) % self.mod1
        h2 = (self.hash2[r + 1] - self.hash2[l] * self.pow2[r - l + 1]) % self.mod2
        return (h1, h2)

get_hash(l, r)

Returns the double hash of substring s[l...r] (inclusive).

Source code in cp_snippets/strings.py
87
88
89
90
91
92
93
94
def get_hash(self, l: int, r: int) -> Tuple[int, int]:
    """
    Returns the double hash of substring s[l...r] (inclusive).
    An empty range (l == r + 1) hashes to (0, 0).

    Raises:
        IndexError: unless 0 <= l <= r + 1 <= len(s).
    """
    n = len(self.hash1) - 1
    if not 0 <= l <= r + 1 <= n:
        # Negative indices or l > r + 1 would silently wrap around the
        # prefix/power tables and produce a meaningless hash.
        raise IndexError("substring range out of bounds")
    # hash[r+1] - hash[l] * base^(r-l+1)
    h1 = (self.hash1[r + 1] - self.hash1[l] * self.pow1[r - l + 1]) % self.mod1
    h2 = (self.hash2[r + 1] - self.hash2[l] * self.pow2[r - l + 1]) % self.mod2
    return (h1, h2)

compute_pi(s)

Computes the prefix function (pi array) for KMP algorithm. pi[i] is the length of the longest proper prefix of s[0...i] that is also a suffix of s[0...i].

Source code in cp_snippets/strings.py
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
def compute_pi(s: str) -> List[int]:
    """
    Builds the KMP prefix function of s.
    Entry i holds the length of the longest proper prefix of s[0...i]
    that is also a suffix of s[0...i]; entry 0 is always 0.
    """
    pi = [0] * len(s)
    k = 0  # border length of the previous prefix, i.e. pi[i - 1]
    for i in range(1, len(s)):
        ch = s[i]
        # Fall back through shorter borders until ch can extend one.
        while k and ch != s[k]:
            k = pi[k - 1]
        if ch == s[k]:
            k += 1
        pi[i] = k
    return pi

is_palindrome(s)

Checks if the string s is a palindrome.

Source code in cp_snippets/strings.py
97
98
99
def is_palindrome(s: str) -> bool:
    """Return True when s reads the same forwards and backwards."""
    half = len(s) // 2
    # Compare the first half with the mirrored last half; an odd-length
    # middle character mirrors onto itself and needs no check.
    return s[:half] == s[len(s) - half:][::-1]

kmp_search(text, pattern)

Finds all occurrences of pattern in text using the KMP algorithm. Returns a list of starting indices (0-based).

Source code in cp_snippets/strings.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def kmp_search(text: str, pattern: str) -> List[int]:
    """
    Locates every (possibly overlapping) occurrence of pattern in text with KMP.
    Returns the 0-based start index of each match; an empty pattern yields [].
    """
    m = len(pattern)
    if m == 0:
        return []

    failure = compute_pi(pattern)
    found: List[int] = []
    matched = 0  # length of the pattern prefix currently matched

    for pos, ch in enumerate(text):
        while matched and ch != pattern[matched]:
            matched = failure[matched - 1]
        if ch == pattern[matched]:
            matched += 1
        if matched == m:
            found.append(pos - m + 1)
            # Continue from the longest border so overlapping matches are found.
            matched = failure[m - 1]

    return found

manacher(s)

Runs Manacher's algorithm to find palindromic substrings. Returns the P array, where P[i] is the palindrome radius centered at T[i]. The input string s is transformed to T = ^#s[0]#s[1]...#s[n-1]#$, where ^ and $ are boundary sentinels and # is a separator, so that even-length palindromes also have a center. The length of the longest palindrome centered at original index i (mapped to T) is P[2*i + 2]. The longest even-length palindrome centered between s[i] and s[i+1] has length P[2*i + 3].

Source code in cp_snippets/strings.py
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
def manacher(s: str) -> List[int]:
    """
    Runs Manacher's algorithm to find palindromic substrings in linear time.
    Returns the P array, where P[i] is the palindrome radius centered at T[i].
    The input string s is conceptually transformed to T = ^#s[0]#s[1]...#s[n-1]#$,
    where ^ and $ are boundary sentinels and # is a separator, so that
    even-length palindromes also have a center. len(P) == 2 * len(s) + 3.
    An empty s returns [].

    The longest palindrome centered at original index i has length P[2*i + 2].
    The longest even-length palindrome centered between s[i] and s[i+1] has
    length P[2*i + 3].

    The sentinels and separator are unique objects, not characters, so s may
    contain any characters (including '^', '#' and '$').
    """
    if not s:
        return []

    # Unique objects compare equal only to themselves. Input characters can
    # therefore never match a sentinel, and expansion always stops inside t.
    begin, sep, end = object(), object(), object()
    t: List[object] = [begin, sep]
    for ch in s:
        t.append(ch)
        t.append(sep)
    t.append(end)

    n = len(t)
    P = [0] * n
    C = 0  # center of the palindrome reaching furthest right
    R = 0  # right edge of that palindrome

    for i in range(1, n - 1):
        if i < R:
            # Reuse the mirror's radius, capped at the known right edge.
            P[i] = min(R - i, P[2 * C - i])
        else:
            P[i] = 0

        # Attempt to expand the palindrome centered at i.
        while t[i + 1 + P[i]] == t[i - 1 - P[i]]:
            P[i] += 1

        # If the palindrome centered at i expands past R, it becomes the new reference.
        if i + P[i] > R:
            C = i
            R = i + P[i]

    # P covers the full transformed string, sentinels included, matching the
    # usual competitive-programming template.
    return P

z_function(s)

Computes the Z-function for string s. z[i] is the length of the longest common prefix between s and the suffix of s starting at i.

Source code in cp_snippets/strings.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def z_function(s: str) -> List[int]:
    """
    Builds the Z-array of s.
    Entry i is the length of the longest common prefix of s and s[i:];
    entry 0 is left as 0 by convention.
    """
    size = len(s)
    z = [0] * size
    left = right = 0  # [left, right] is the rightmost known segment equal to a prefix of s
    for i in range(1, size):
        k = 0
        if i <= right:
            # Start from what the mirrored position already guarantees.
            k = min(right - i + 1, z[i - left])
        while i + k < size and s[k] == s[i + k]:
            k += 1
        z[i] = k
        if i + k - 1 > right:
            left, right = i, i + k - 1
    return z