Loops in Rust
Let's implement a Base64 encoder. We will write a function that takes an array of arbitrary bytes as input and returns an array of symbols from the Base64 alphabet.
As the name implies, there are 64 symbols in the Base64 alphabet. That means that each symbol encodes 6 bits. So we would like to iterate through our input 6 bits at a time. But our input is provided as an array of bytes and a byte is 8 bits. The least common multiple of 6 and 8 is 24, so we will have to step through the input 3 bytes at a time, and write out 4 symbols at each step.
What actually is a Base64 symbol? The Base64 encoding is designed to allow arbitrary binary data to be sent through systems that expect to deal with text. Therefore, each symbol is a character: usually an ASCII character. If we use the UTF-8 encoding for our strings, each character will be one byte.
Thus our function will read 3 bytes at a time and write 4 bytes. This 3:4 ratio is the encoding overhead for Base64. It means that any data will expand to 133% of its unencoded size when encoded using Base64. This is the price we pay to send binary data as text.
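To make the bit layout concrete, here's a small worked example using the three input bytes "Man" (this snippet just illustrates the arithmetic, it isn't part of the encoder we're about to write):
// Splitting the three bytes of "Man" into four 6-bit values.
fn main() {
    let (b0, b1, b2) = (b'M' as u32, b'a' as u32, b'n' as u32);
    let group = b0 << 16 | b1 << 8 | b2; // one 24-bit group
    let indices = [
        group >> 18 & 0x3F, // 0b010011 = 19 -> 'T'
        group >> 12 & 0x3F, // 0b010110 = 22 -> 'W'
        group >> 6 & 0x3F,  // 0b000101 =  5 -> 'F'
        group & 0x3F,       // 0b101110 = 46 -> 'u'
    ];
    assert_eq!(indices, [19, 22, 5, 46]); // i.e. "Man" encodes to "TWFu"
}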
We'll start with the simplest case, in which our input is an exact multiple of 3 bytes long.
An implementation in C might look like this:
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
char encoding[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
char *encode(uint8_t *src, size_t src_len) {
    assert(src_len % 3 == 0);
    size_t dst_len = (src_len / 3) * 4;
    char *dst = malloc(dst_len);
    size_t si = 0;
    size_t di = 0;
    while (si < src_len) {
        uint32_t val = (src[si] << 16) | (src[si+1] << 8) | src[si+2];
        dst[di+0] = encoding[val >> 18 & 0x3F];
        dst[di+1] = encoding[val >> 12 & 0x3F];
        dst[di+2] = encoding[val >> 6 & 0x3F];
        dst[di+3] = encoding[val & 0x3F];
        si += 3;
        di += 4;
    }
    return dst;
}
The algorithm is as follows:
1. Allocate a suitably sized output array
2. Initialize two indices to zero
3. Read three bytes from the source array using the source index
4. Split the three bytes into four 6-bit values
5. Encode the four values and write them out using the destination index
6. Increment the source index by three
7. Increment the destination index by four
8. Repeat from step 3 until we've encoded all the source bytes
Let's take a closer look at the loop. I will use the Compiler Explorer to view the resulting assembly. I'll include a link so that you can look up the complete output if you're interested, but I'll only include relevant excerpts here.
I will target x86-64; please follow the links to see the exact compiler versions and options used.
The above C implementation gives us this [1]:
; Setup and then...
test r14, r14 ; if (src_len == 0)
je .LBB0_4 ; goto .LBB0_4
mov rcx, rax ; dst is result of malloc
add rcx, 3 ; dst += 3
xor edx, edx ; si = 0
.LBB0_3:
movzx ebx, byte ptr [r15 + rdx] ; load src[si]
movzx edi, byte ptr [r15 + rdx + 1] ; load src[si+1]
shl edi, 8
mov esi, ebx
shl esi, 16
or esi, edi
movzx r8d, byte ptr [r15 + rdx + 2] ; load src[si+2]
or edi, r8d
shr rbx, 2
movzx ebx, byte ptr [rbx + encoding] ; load 1st symbol
mov byte ptr [rcx - 3], bl ; store at dst-3
shr esi, 12
and esi, 63
movzx ebx, byte ptr [rsi + encoding] ; load 2nd symbol
mov byte ptr [rcx - 2], bl ; store at dst-2
shr edi, 6
and edi, 63
movzx ebx, byte ptr [rdi + encoding] ; load 3rd symbol
mov byte ptr [rcx - 1], bl ; store at dst-1
and r8d, 63
movzx ebx, byte ptr [r8 + encoding] ; load 4th symbol
mov byte ptr [rcx], bl ; store at dst
add rdx, 3 ; si += 3
add rcx, 4 ; dst += 4
cmp rdx, r14 ; if (si < src_len)
jb .LBB0_3 ; goto .LBB0_3
.LBB0_4:
; The end
It's pretty much what you'd expect. There's a one-off check to make sure that src_len isn't zero and a jump past the whole loop if it is. Inside the loop we have an index in rdx that corresponds to the si index in the C source, and we have the dst pointer in rcx that we increment directly. We increment one register by 3 and the other by 4 on each iteration through the loop. It's a fairly direct translation of the C code.
We can write the same algorithm in Rust:
const ENCODING: [u8; 64] = [
    b'A', b'B', b'C', b'D', b'E', b'F', b'G', b'H',
    b'I', b'J', b'K', b'L', b'M', b'N', b'O', b'P',
    b'Q', b'R', b'S', b'T', b'U', b'V', b'W', b'X',
    b'Y', b'Z', b'a', b'b', b'c', b'd', b'e', b'f',
    b'g', b'h', b'i', b'j', b'k', b'l', b'm', b'n',
    b'o', b'p', b'q', b'r', b's', b't', b'u', b'v',
    b'w', b'x', b'y', b'z', b'0', b'1', b'2', b'3',
    b'4', b'5', b'6', b'7', b'8', b'9', b'+', b'/'
];

pub fn encode(src: &[u8]) -> Vec<u8> {
    assert!(src.len() % 3 == 0);
    let dst_len = (src.len() / 3) * 4;
    let mut dst = vec![0 as u8; dst_len];
    let mut si = 0;
    let mut di = 0;
    let n = src.len();
    loop {
        if si >= n {
            break;
        }
        let val = (src[si+0] as u32) << 16 |
                  (src[si+1] as u32) << 8 |
                  (src[si+2] as u32);
        dst[di+0] = ENCODING[(val >> 18 & 0x3F) as usize];
        dst[di+1] = ENCODING[(val >> 12 & 0x3F) as usize];
        dst[di+2] = ENCODING[(val >> 6 & 0x3F) as usize];
        dst[di+3] = ENCODING[(val & 0x3F) as usize];
        si += 3;
        di += 4;
    }
    dst
}
The first striking thing about this Rust version's assembly [2] is the increase in complexity. Where the C version had only three branches — the assertion, the loop condition and a check that src_len isn't zero — the Rust version has eleven!
This is because the Rust compiler has automatically inserted some error checking. This is generally good: for example, the compiler has inserted a check that the memory allocation for the dst vector succeeds. I did not bother to check what malloc returns in the C version, so that program would crash if the allocation failed.
It's great that Rust enforces such checks automatically... although you can't currently make it do anything other than abort the program if the check fails, so there's no chance to gracefully handle running out of memory. (There is a proposed API to let the programmer set their own Out-Of-Memory handler but it is currently experimental and not included in stable Rust releases [3].)
The Rust compiler has also inserted bounds checking for all of the array accesses. This is a nice thing in the general case — buffer overflows are the cause of many bugs in C code — but in this case it is a little annoying. We know the sizes of both arrays and we iterate through the arrays based on those sizes: there is no need for run-time bounds checking as we can prove at compile-time that we will never read or write outside the bounds of the arrays.
I had hoped that the bounds checks would be elided, since static analysis can prove that they are unnecessary, but it appears that in this case they are not. (The compiler options used are, at the time of writing, the same as those used by Cargo's default "release" profile.)
This is what we get:
; Setup and then...
test r14, r14 ; if (src_len == 0)
je .LBB8_18 ; goto .LBB8_18
mov ebx, 2 ; si = 2
mov esi, 3 ; di = 3
lea r8, [rip + .L__unnamed_7] ; ENCODING array
.LBB8_6:
lea rax, [rbx - 1]
cmp rax, r14 ; if (si - 1 >= src_len)
jae .LBB8_20 ; goto OUT_OF_BOUNDS
cmp rbx, r14 ; if (si >= src_len)
jae .LBB8_21 ; goto OUT_OF_BOUNDS
lea rax, [rsi - 3]
cmp r13, rax ; if (dst_len <= di - 3)
jbe .LBB8_9 ; goto OUT_OF_BOUNDS
movzx ecx, byte ptr [r12 + rbx - 2] ; load src[si-2]
movzx edx, byte ptr [r12 + rbx - 1] ; load src[si-1]
movzx edi, byte ptr [r12 + rbx] ; load src[si]
mov rax, rcx
shr rax, 2
movzx eax, byte ptr [rax + r8] ; load 1st symbol
mov byte ptr [r9 + rsi - 3], al ; store at dst[di-3]
lea rax, [rsi - 2]
cmp r13, rax ; if (dst_len <= di - 2)
jbe .LBB8_12 ; goto OUT_OF_BOUNDS
shl rdx, 8
shl ecx, 16
or ecx, edx
shr ecx, 12
and ecx, 63
movzx eax, byte ptr [rcx + r8] ; load 2nd symbol
mov byte ptr [r9 + rsi - 2], al ; store at dst[di-2]
lea rax, [rsi - 1]
cmp r13, rax ; if (dst_len <= di - 1)
jbe .LBB8_14 ; goto OUT_OF_BOUNDS
or rdx, rdi
shr edx, 6
and edx, 63
movzx eax, byte ptr [rdx + r8] ; load 3rd symbol
mov byte ptr [r9 + rsi - 1], al ; store at dst[di-1]
cmp r13, rsi ; if (dst_len <= di)
jbe .LBB8_16 ; goto OUT_OF_BOUNDS
and edi, 63
movzx eax, byte ptr [rdi + r8] ; load 4th symbol
mov byte ptr [r9 + rsi], al ; store at dst[di]
lea rax, [rbx + 3]
add rbx, 1
add rsi, 4 ; di += 4
cmp rbx, r14 ; if (si + 1 < src_len)
mov rbx, rax ; si = si + 3
jb .LBB8_6 ; goto .LBB8_6
.LBB8_18:
; The end
Note that the bit-twiddling parts are identical to the C version. I compiled the C version with Clang, which uses the same LLVM backend as the Rust compiler, so this shouldn't be too much of a surprise.
The main difference is that it doesn't increment a destination pointer directly, as the C version did, but instead keeps a di index. This is because it needs to use di when bounds checking accesses to the dst array. So, apart from the bounds checking, the Rust version and the C version are almost identical.
The bounds checking probably won't make much difference in this case: the calculations are all register-only operations and hardware branch prediction will minimize any effects of the branches themselves. However, it does have an effect on register allocation and code size... and it just kind of irks me that they are there when we can see clearly at compile time that they aren't necessary: it would be nice if we could get rid of them.
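If you wanted to quantify the impact rather than guess, a minimal benchmark would do it. Here's a sketch using the Criterion crate; the crate name base64_loops is just a stand-in for wherever encode ends up living:
// benches/encode.rs: a sketch of a Criterion benchmark.
// `base64_loops` is a placeholder crate name for wherever `encode` lives.
use base64_loops::encode;
use criterion::{black_box, criterion_group, criterion_main, Criterion};

fn bench_encode(c: &mut Criterion) {
    // 3072 bytes of arbitrary data; a multiple of 3 so the assert passes.
    let input: Vec<u8> = (0u32..3072).map(|i| (i % 251) as u8).collect();
    c.bench_function("base64 encode 3 KiB", |b| {
        b.iter(|| encode(black_box(&input)))
    });
}

criterion_group!(benches, bench_encode);
criterion_main!(benches);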
Our Rust version also doesn't seem very idiomatic. One of the great promises of Rust is that it allows us to write fast, low-level code in a high-level, functional programming style. We shouldn't be manually incrementing array indices; we should be using iterators!
Here's another version:
pub fn encode(src: &[u8]) -> Vec<u8> {
    assert!(src.len() % 3 == 0);
    let dst_len = (src.len() / 3) * 4;
    let mut dst = vec![0 as u8; dst_len];
    for (s, d) in src.chunks_exact(3).zip(dst.chunks_exact_mut(4)) {
        let val = (s[0] as u32) << 16 | (s[1] as u32) << 8 | (s[2] as u32);
        d[0] = ENCODING[(val >> 18 & 0x3F) as usize];
        d[1] = ENCODING[(val >> 12 & 0x3F) as usize];
        d[2] = ENCODING[(val >> 6 & 0x3F) as usize];
        d[3] = ENCODING[(val & 0x3F) as usize];
    }
    dst
}
That one gives us this [4]:
; Setup and then...
test r8, r8 ; if (src_len/3 == 0)
je .LBB7_7 ; goto .LBB7_7
add r13, 2 ; src += 2
xor edx, edx ; i = 0
lea r9, [rip + .L__unnamed_7] ; ENCODING array
.LBB7_6:
movzx esi, byte ptr [r13 - 2] ; load src-2
movzx ecx, byte ptr [r13 - 1] ; load src-1
shl ecx, 8
mov edi, esi
shl edi, 16
or edi, ecx
movzx r10d, byte ptr [r13] ; load src
or ecx, r10d
shr rsi, 2
movzx ebx, byte ptr [rsi + r9] ; load 1st symbol
mov byte ptr [rax + 4*rdx], bl ; store at dst[4*i]
shr edi, 12
and edi, 63
movzx ebx, byte ptr [rdi + r9] ; load 2nd symbol
mov byte ptr [rax + 4*rdx + 1], bl ; store at dst[4*i+1]
shr ecx, 6
and ecx, 63
movzx ecx, byte ptr [rcx + r9] ; load 3rd symbol
mov byte ptr [rax + 4*rdx + 2], cl ; store at dst[4*i+2]
and r10d, 63
movzx ecx, byte ptr [r10 + r9] ; load 4th symbol
mov byte ptr [rax + 4*rdx + 3], cl ; store at dst[4*i+3]
lea rdx, [rdx + 1] ; i += 1
add r13, 3 ; src += 3
cmp rdx, r8 ; if (i < src_len/3)
jb .LBB7_6 ; goto .LBB7_6
.LBB7_7:
; The end
Hooray! No bounds checking! This one is now almost identical to the C version. We have one index and one array pointer that we increment directly.
Whether this version is as readable as the first version is perhaps a matter of taste — but we can definitely do better. We're still allocating an array and then iterating through it: if we're really going to do justice to Rust's OCaml heritage, we should be more functional still:
pub fn encode(src: &[u8]) -> Vec<u8> {
    src.chunks_exact(3)
        .map(|s| (s[0] as u32) << 16 | (s[1] as u32) << 8 | (s[2] as u32))
        .flat_map(|v| vec![ENCODING[(v >> 18 & 0x3F) as usize],
                           ENCODING[(v >> 12 & 0x3F) as usize],
                           ENCODING[(v >> 6 & 0x3F) as usize],
                           ENCODING[(v & 0x3F) as usize]])
        .collect()
}
Unfortunately, that compiles to something far more convoluted [5]. The inner loop has gone from 26 instructions to 115! Worse, the closures that we passed to map and flat_map have been compiled to a separate function which we call on every iteration.
And there's an even bigger problem caused by that vec! macro in the flat_map closure. It's there because we need to return an iterator and Rust does not provide iterators over arrays or tuples. We could return a slice, but then we'd have an iterator over references. That's no good because they would be references to the temporary variables created in that closure, so they couldn't live longer than the closure itself, but we need them later to collect them into the output. The only convenient option provided by core Rust is a Vec... but that results in memory allocation and deallocation on every iteration!
We can fix this by making our own iterator that works with arrays. We need something that will take an array of four elements and will return a copy of each element on request:
struct ArrayIter<T> {
    a: [T; 4],
    i: usize,
}

impl<T> ArrayIter<T> {
    fn new(a: [T; 4]) -> Self {
        Self { a, i: 0 }
    }
}

impl<T> Iterator for ArrayIter<T>
where T: Copy,
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        if self.i < 4 {
            let value = self.a[self.i];
            self.i += 1;
            Some(value)
        } else {
            None
        }
    }
}

pub fn encode(src: &[u8]) -> Vec<u8> {
    src.chunks_exact(3)
        .map(|s| (s[0] as u32) << 16 | (s[1] as u32) << 8 | (s[2] as u32))
        .flat_map(|v| ArrayIter::new(
            [ENCODING[(v >> 18 & 0x3F) as usize],
             ENCODING[(v >> 12 & 0x3F) as usize],
             ENCODING[(v >> 6 & 0x3F) as usize],
             ENCODING[(v & 0x3F) as usize]]))
        .collect()
}
This is better [6]: we've eliminated the heap allocations and deallocations in the flat_map closure. But we've reintroduced bounds checking!
Also, we still have another memory-related problem: as we loop through the src array, calling our map/flat_map closure on every iteration, we incrementally populate a result Vec. But unlike in our earlier version, where we allocated dst once at the beginning and then wrote our results into it, this version writes into the result Vec until it is full and then reallocates as needed. We know exactly how many output symbols we are going to create but collect does not, so it must dynamically reallocate memory inside our loop.
Maybe going fully functional wasn't such a good idea. It's definitely not as efficient as just writing the loops explicitly.
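That said, the reallocation problem at least has a simple workaround if you want to keep the pipeline: allocate the Vec up front with Vec::with_capacity and extend it, instead of letting collect grow it. Here's a sketch, reusing the ArrayIter above (note that this does nothing about the bounds checks):
pub fn encode(src: &[u8]) -> Vec<u8> {
    assert!(src.len() % 3 == 0);
    // Reserve the full output size up front so extend does not need to reallocate.
    let mut dst = Vec::with_capacity((src.len() / 3) * 4);
    dst.extend(
        src.chunks_exact(3)
            .map(|s| (s[0] as u32) << 16 | (s[1] as u32) << 8 | (s[2] as u32))
            .flat_map(|v| ArrayIter::new([
                ENCODING[(v >> 18 & 0x3F) as usize],
                ENCODING[(v >> 12 & 0x3F) as usize],
                ENCODING[(v >> 6 & 0x3F) as usize],
                ENCODING[(v & 0x3F) as usize],
            ])),
    );
    dst
}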
Finally, let's handle the case when the input is not a multiple of 3 bytes. In that case, the Base64 output should still be a multiple of 4 symbols: we have to add padding symbols, usually =, to make up the difference. For example, a single remaining input byte produces two symbols followed by ==, while two remaining bytes produce three symbols followed by a single =.
Here's a version in C:
char *encode(uint8_t *src, size_t src_len) {
    size_t dst_len = (src_len + 2) / 3 * 4;
    char *dst = malloc(dst_len);
    size_t si = 0;
    size_t di = 0;
    size_t n = (src_len / 3) * 3;
    while (si < n) {
        uint32_t val = (src[si] << 16) | (src[si+1] << 8) | src[si+2];
        dst[di+0] = encoding[val >> 18 & 0x3F];
        dst[di+1] = encoding[val >> 12 & 0x3F];
        dst[di+2] = encoding[val >> 6 & 0x3F];
        dst[di+3] = encoding[val & 0x3F];
        si += 3;
        di += 4;
    }
    size_t remainder = src_len - si;
    if (remainder == 0) {
        return dst;
    }
    uint32_t val = src[si] << 16;
    if (remainder == 2) {
        val |= src[si+1] << 8;
    }
    dst[di+0] = encoding[val >> 18 & 0x3F];
    dst[di+1] = encoding[val >> 12 & 0x3F];
    switch (remainder) {
    case 2:
        dst[di+2] = encoding[val >> 6 & 0x3F];
        break;
    case 1:
        dst[di+2] = '=';
        break;
    }
    dst[di+3] = '=';
    return dst;
}
It's the same as before but now we handle the remaining bytes when the input isn't a multiple of 3 bytes long.
We can do something similar in Rust, but we have to be careful to clone and borrow the iterators as appropriate:
pub fn encode(src: &[u8]) -> Vec<u8> {
    let dst_len = (src.len() + 2) / 3 * 4;
    let mut dst = vec![0 as u8; dst_len];
    let src_iter = src.chunks_exact(3);
    let src_remainder = src_iter.clone().remainder();
    let mut dst_iter = dst.chunks_exact_mut(4);
    for (s, d) in src_iter.zip(&mut dst_iter) {
        let val = (s[0] as u32) << 16 | (s[1] as u32) << 8 | (s[2] as u32);
        d[0] = ENCODING[(val >> 18 & 0x3F) as usize];
        d[1] = ENCODING[(val >> 12 & 0x3F) as usize];
        d[2] = ENCODING[(val >> 6 & 0x3F) as usize];
        d[3] = ENCODING[(val & 0x3F) as usize];
    }
    if src_remainder.len() > 0 {
        let mut val = (src_remainder[0] as u32) << 16;
        if src_remainder.len() > 1 {
            val |= (src_remainder[1] as u32) << 8;
        }
        let dst_remainder = dst_iter.last().unwrap();
        dst_remainder[0] = ENCODING[(val >> 18 & 0x3F) as usize];
        dst_remainder[1] = ENCODING[(val >> 12 & 0x3F) as usize];
        match src_remainder.len() {
            2 => {
                dst_remainder[2] = ENCODING[(val >> 6 & 0x3F) as usize];
            },
            1 => {
                dst_remainder[2] = b'=';
            },
            _ => (),
        }
        dst_remainder[3] = b'=';
    }
    dst
}
As before, the C and the Rust produce pretty much equivalent assembly [7][8]. The Rust version has a couple of extra checks (such as the unwrap) but it's very similar.
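A quick way to convince yourself that this final version behaves correctly is to run it against the test vectors from RFC 4648; something like this, assuming the tests live alongside encode:
#[cfg(test)]
mod tests {
    use super::encode;

    // Test vectors from RFC 4648, section 10.
    #[test]
    fn rfc4648_vectors() {
        assert_eq!(encode(b""), b"".to_vec());
        assert_eq!(encode(b"f"), b"Zg==".to_vec());
        assert_eq!(encode(b"fo"), b"Zm8=".to_vec());
        assert_eq!(encode(b"foo"), b"Zm9v".to_vec());
        assert_eq!(encode(b"foob"), b"Zm9vYg==".to_vec());
        assert_eq!(encode(b"fooba"), b"Zm9vYmE=".to_vec());
        assert_eq!(encode(b"foobar"), b"Zm9vYmFy".to_vec());
    }
}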
The Rust Book says that Rust strives to provide "higher-level features that compile to lower-level code as fast as code written manually" [9]. This appears to be the case... as long as you pick the right features.