// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2024, SUSE LLC
 *
 * Implementation of the LZ77 "plain" compression algorithm, as per MS-XCA spec.
 */

#include <linux/slab.h>
#include <linux/sizes.h>
#include <linux/count_zeros.h>
#include <linux/unaligned.h>

/*
 * Compression parameters.
 */
#define LZ77_MATCH_MIN_LEN	4
#define LZ77_MATCH_MIN_DIST	1
#define LZ77_MATCH_MAX_DIST	SZ_1K
#define LZ77_HASH_LOG		15
#define LZ77_HASH_SIZE		(1 << LZ77_HASH_LOG)
#define LZ77_STEP_SIZE		sizeof(u64)

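/*
 * All multi-byte values are accessed through the unaligned helpers below, as
 * both the input window and the output stream are walked at arbitrary byte
 * offsets; 16- and 32-bit stores are little-endian, as MS-XCA specifies.
 */
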
static __always_inline u8 lz77_read8(const u8 *ptr)
{
	return get_unaligned(ptr);
}

static __always_inline u64 lz77_read64(const u64 *ptr)
{
	return get_unaligned(ptr);
}

static __always_inline void lz77_write8(u8 *ptr, u8 v)
{
	put_unaligned(v, ptr);
}

static __always_inline void lz77_write16(u16 *ptr, u16 v)
{
	put_unaligned_le16(v, ptr);
}

static __always_inline void lz77_write32(u32 *ptr, u32 v)
{
	put_unaligned_le32(v, ptr);
}

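/*
 * Return how many bytes match between @wnd and @cur, comparing 8 bytes at a
 * time: the XOR of the two loads is non-zero at the first differing byte, and
 * count_trailing_zeros() / 8 converts that (little-endian) into a byte count.
 */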
static __always_inline u32 lz77_match_len(const void *wnd, const void *cur, const void *end)
{
	const void *start = cur;
	u64 diff;

	/* Safe for a do/while because otherwise we wouldn't reach here from the main loop. */
	do {
		diff = lz77_read64(cur) ^ lz77_read64(wnd);
		if (!diff) {
			cur += LZ77_STEP_SIZE;
			wnd += LZ77_STEP_SIZE;
			continue;
		}

		/* This computes the number of common bytes in @diff. */
		cur += count_trailing_zeros(diff) >> 3;

		return cur - start;
	} while (likely(cur + LZ77_STEP_SIZE < end));

	while (cur < end && lz77_read8(cur++) == lz77_read8(wnd++))
		;

	return cur - start;
}

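/*
 * Emit one match in the MS-XCA "plain" encoding: a 16-bit value carrying
 * (distance - 1) in its high 13 bits and min(length - 3, 7) in its low 3 bits,
 * followed, for longer matches, by a 4-bit length extension (the nibble byte
 * is shared between two long matches via @nib), then an 8-bit one, then a
 * 16-bit or 32-bit one.
 *
 * For illustration: dist = 32, len = 100 becomes the 16-bit value 0x00ff
 * (31 << 3 | 7), a nibble of 15, and an extension byte of 75 (100 - 3 - 7 - 15).
 */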
static __always_inline void *lz77_write_match(void *dst, void **nib, u32 dist, u32 len)
{
	len -= 3;
	dist--;
	dist <<= 3;

	if (len < 7) {
		lz77_write16(dst, dist + len);
		return dst + 2;
	}

	dist |= 7;
	lz77_write16(dst, dist);
	dst += 2;
	len -= 7;

	if (!*nib) {
		lz77_write8(dst, umin(len, 15));
		*nib = dst;
		dst++;
	} else {
		u8 *b = *nib;

		lz77_write8(b, *b | umin(len, 15) << 4);
		*nib = NULL;
	}

	if (len < 15)
		return dst;

	len -= 15;
	if (len < 255) {
		lz77_write8(dst, len);
		return dst + 1;
	}

	lz77_write8(dst, 0xff);
	dst++;
	len += 7 + 15;
	if (len <= 0xffff) {
		lz77_write16(dst, len);
		return dst + 2;
	}

	lz77_write16(dst, 0);
	dst += 2;
	lz77_write32(dst, len);
	return dst + 4;
}

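/*
 * Greedy single-pass compressor: a (1 << LZ77_HASH_LOG)-entry table maps a
 * hash of the upcoming bytes to the last offset where they were seen; a
 * candidate closer than LZ77_MATCH_MAX_DIST that matches for at least
 * LZ77_MATCH_MIN_LEN bytes is emitted as a match, anything else as a literal.
 * Each group of 32 elements is preceded by a 32-bit flag word whose bits tell,
 * MSB first, whether the corresponding element is a match (1) or a literal (0).
 */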
noinline int lz77_compress(const void *src, u32 slen, void *dst, u32 *dlen)
{
	const void *srcp, *end;
	void *dstp, *nib, *flag_pos;
	u32 flag_count = 0;
	long flag = 0;
	u32 *htable;
	int ret;

	srcp = src;
	end = src + slen;
	dstp = dst;
	nib = NULL;
	flag_pos = dstp;
	dstp += 4;

	htable = kvcalloc(LZ77_HASH_SIZE, sizeof(*htable), GFP_KERNEL);
	if (!htable)
		return -ENOMEM;

	/* Main loop. */
	do {
		u32 dist, len = 0;
		const void *wnd;
		u64 hash;

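		/*
		 * Multiplicative hash of the bytes at @srcp (the << 24 drops the
		 * top three bytes of the 64-bit load, so five bytes take part),
		 * indexing the most recent position that hashed the same way.
		 */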
		hash = ((lz77_read64(srcp) << 24) * 889523592379ULL) >> (64 - LZ77_HASH_LOG);
		wnd = src + htable[hash];
		htable[hash] = srcp - src;
		dist = srcp - wnd;

		if (dist && dist < LZ77_MATCH_MAX_DIST)
			len = lz77_match_len(wnd, srcp, end);

		if (len < LZ77_MATCH_MIN_LEN) {
			lz77_write8(dstp, lz77_read8(srcp));
			dstp++;
			srcp++;

			flag <<= 1;
			flag_count++;
			if (flag_count == 32) {
				lz77_write32(flag_pos, flag);
				flag_count = 0;
				flag_pos = dstp;
				dstp += 4;
			}

			continue;
		}

		/*
		 * Bail out if @dstp reached >= 7/8 of @slen -- already compressed badly, not worth
		 * going further.
		 */
		if (unlikely(dstp - dst >= slen - (slen >> 3))) {
			ret = -EMSGSIZE;
			goto err_free;
		}

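		/* Emit the match, skip the matched bytes, and record a 1 flag bit. */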
		dstp = lz77_write_match(dstp, &nib, dist, len);
		srcp += len;

		flag = (flag << 1) | 1;
		flag_count++;
		if (flag_count == 32) {
			lz77_write32(flag_pos, flag);
			flag_count = 0;
			flag_pos = dstp;
			dstp += 4;
		}
	} while (likely(srcp + LZ77_STEP_SIZE < end));

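	/*
	 * At most LZ77_STEP_SIZE bytes remain, too few for the 8-byte loads in
	 * the main loop, so flush them as plain literals (flag bits left at 0).
	 */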
	while (srcp < end) {
		u32 c = umin(end - srcp, 32 - flag_count);

		memcpy(dstp, srcp, c);
		dstp += c;
		srcp += c;

		flag <<= c;
		flag_count += c;
		if (flag_count == 32) {
			lz77_write32(flag_pos, flag);
			flag_count = 0;
			flag_pos = dstp;
			dstp += 4;
		}
	}

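	/*
	 * Terminate the stream: move the pending flag bits up to the top of the
	 * final flag word and set every remaining (unused) low bit to 1.
	 */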
	flag <<= (32 - flag_count);
	flag |= (1 << (32 - flag_count)) - 1;
	lz77_write32(flag_pos, flag);

	*dlen = dstp - dst;
	ret = 0;
err_free:
	kvfree(htable);

	return ret;
}

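/*
 * Illustrative call site only (not part of this file; buffer names are
 * hypothetical and sizing @dst is the caller's responsibility):
 *
 *	u32 dlen = 0;
 *	int rc = lz77_compress(src_buf, src_len, dst_buf, &dlen);
 *
 *	if (!rc)
 *		pr_debug("compressed %u bytes down to %u\n", src_len, dlen);
 */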