This blog post will not be very rigorous, but I wanted to highlight an example that I often reach for when discussing complicated async iterators.
First, I should provide context.
Currently, we have an AsyncIterator trait defined as:
pub trait AsyncIterator {
    type Item;
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>>;
}
This is a mashup of the Future and Iterator traits:
pub trait Future {
    type Output;
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>;
}
pub trait Iterator {
    type Item;
    fn next(&mut self) -> Option<Self::Item>;
}
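To make the "mashup" concrete, here is a minimal sketch of a hand-written poll_next implementation. It assumes a local copy of the trait (the std version is still unstable); Countdown, NoopWake and collect_all are hypothetical names made up for this illustration, not part of any library.

```rust
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Local copy of the trait shape shown above (assumption: we cannot use the
// unstable std version on stable Rust).
pub trait AsyncIterator {
    type Item;
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>;
}

// Hypothetical example type: yields remaining, remaining - 1, ..., 1 and
// returns Pending once before every item to mimic waiting on I/O.
pub struct Countdown {
    remaining: u32,
    ready: bool,
}

impl AsyncIterator for Countdown {
    type Item = u32;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u32>> {
        if self.remaining == 0 {
            // The Iterator half: the sequence is over.
            return Poll::Ready(None);
        }
        if !self.ready {
            // The Future half: not ready yet, ask to be polled again.
            self.ready = true;
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }
        self.ready = false;
        let item = self.remaining;
        self.remaining -= 1;
        Poll::Ready(Some(item))
    }
}

// A waker that does nothing; enough for a busy-polling demo driver.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

// Busy-polls the iterator until it finishes. Fine for a demo, not for real use.
pub fn collect_all<I: AsyncIterator + Unpin>(mut it: I) -> Vec<I::Item> {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut out = Vec::new();
    loop {
        match Pin::new(&mut it).poll_next(&mut cx) {
            Poll::Ready(Some(item)) => out.push(item),
            Poll::Ready(None) => return out,
            Poll::Pending => continue,
        }
    }
}
```

Note that the "am I waiting or not" flag has to live in the struct, because nothing else survives between calls to poll_next.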
This AsyncIterator trait has been working fine for many Rust developers, as exposed by the equivalent futures::Stream trait.
However, some people would like AsyncIterator to be defined using async fn next(&mut self) instead.
This sounds like it simplifies some things. Until recently, traits could not contain async functions, but now they can, so it is worth considering.
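For comparison, here is a rough sketch of what the async fn flavour could look like on stable Rust. The trait name AsyncNext, the Countdown type and the tiny block_on driver are all assumptions made up for this example; none of them is a std or proposed API.

```rust
use std::future::Future;
use std::pin::pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Hypothetical trait name; only the shape of `next` matters here.
trait AsyncNext {
    type Item;
    async fn next(&mut self) -> Option<Self::Item>;
}

// Hypothetical example type: yields remaining, remaining - 1, ..., 1.
struct Countdown {
    remaining: u32,
}

impl AsyncNext for Countdown {
    type Item = u32;

    async fn next(&mut self) -> Option<u32> {
        if self.remaining == 0 {
            return None;
        }
        let item = self.remaining;
        self.remaining -= 1;
        Some(item)
    }
}

// A waker that does nothing; enough for a busy-polling demo driver.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

// Minimal busy-polling executor. Fine for a demo, not for real use.
fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut fut = pin!(fut);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
    }
}

// Drains the iterator with a plain `while let` loop over `.await`.
fn collect_all() -> Vec<u32> {
    block_on(async {
        let mut it = Countdown { remaining: 3 };
        let mut out = Vec::new();
        while let Some(item) = it.next().await {
            out.push(item);
        }
        out
    })
}
```

For a toy like this the async fn version is shorter, which is exactly the appeal.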
There is already a blog post by withoutboats explaining why this is a bad idea, so I won't re-hash any of that here.
The reason you might want an async fn next() is that it supposedly makes the trait easier to implement, but I disagree.
To demonstrate, let me show you an example.
In the tokio-util crate, there is an AsyncIterator called FramedRead. It reads bytes from an AsyncRead and outputs decoded frames. The implementation of this is a sight to behold:
impl<T, U, R> Stream for FramedImpl<T, U, R>
where
T: AsyncRead,
U: Decoder,
R: BorrowMut<ReadFrame>,
{
type Item = Result<U::Item, U::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
use crate::util::poll_read_buf;
let mut pinned = self.project();
let state: &mut ReadFrame = pinned.state.borrow_mut();
// The following loops implements a state machine with each state corresponding
// to a combination of the `is_readable` and `eof` flags. States persist across
// loop entries and most state transitions occur with a return.
//
// The initial state is `reading`.
//
// | state | eof | is_readable | has_errored |
// |---------|-------|-------------|-------------|
// | reading | false | false | false |
// | framing | false | true | false |
// | pausing | true | true | false |
// | paused | true | false | false |
// | errored | <any> | <any> | true |
// `decode_eof` returns Err
// ┌────────────────────────────────────────────────────────┐
// `decode_eof` returns │ │
// `Ok(Some)` │ │
// ┌─────┐ │ `decode_eof` returns After returning │
// Read 0 bytes ├─────▼──┴┐ `Ok(None)` ┌────────┐ ◄───┐ `None` ┌───▼─────┐
// ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐ └───────────┤ Errored │
// │ └─────────┘ └─┬──▲───┘ │ └───▲───▲─┘
// Pending read │ │ │ │ │ │
// ┌──────┐ │ `decode` returns `Some` │ └─────┘ │ │
// │ │ │ ┌──────┐ │ Pending │ │
// │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐ read n>0 bytes │ read │ │
// └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘ │ │
// └──┬─▲────┘ └─────┬──┬┘ │ │
// │ │ │ │ `decode` returns Err │ │
// │ └───decode` returns `None`──┘ └───────────────────────────────────────────────────────┘ │
// │ read returns Err │
// └────────────────────────────────────────────────────────────────────────────────────────────┘
loop {
// Return `None` if we have encountered an error from the underlying decoder
// See: https://github.com/tokio-rs/tokio/issues/3976
if state.has_errored {
// preparing has_errored -> paused
trace!("Returning None and setting paused");
state.is_readable = false;
state.has_errored = false;
return Poll::Ready(None);
}
// Repeatedly call `decode` or `decode_eof` while the buffer is "readable",
// i.e. it _might_ contain data consumable as a frame or closing frame.
// Both signal that there is no such data by returning `None`.
//
// If `decode` couldn't read a frame and the upstream source has returned eof,
// `decode_eof` will attempt to decode the remaining bytes as closing frames.
//
// If the underlying AsyncRead is resumable, we may continue after an EOF,
// but must finish emitting all of it's associated `decode_eof` frames.
// Furthermore, we don't want to emit any `decode_eof` frames on retried
// reads after an EOF unless we've actually read more data.
if state.is_readable {
// pausing or framing
if state.eof {
// pausing
let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| {
trace!("Got an error, going to errored state");
state.has_errored = true;
err
})?;
if frame.is_none() {
state.is_readable = false; // prepare pausing -> paused
}
// implicit pausing -> pausing or pausing -> paused
return Poll::Ready(frame.map(Ok));
}
// framing
trace!("attempting to decode a frame");
if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| {
trace!("Got an error, going to errored state");
state.has_errored = true;
op
})? {
trace!("frame decoded from buffer");
// implicit framing -> framing
return Poll::Ready(Some(Ok(frame)));
}
// framing -> reading
state.is_readable = false;
}
// reading or paused
// If we can't build a frame yet, try to read more data and try again.
// Make sure we've got room for at least one byte to read to ensure
// that we don't get a spurious 0 that looks like EOF.
state.buffer.reserve(1);
let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err(
|err| {
trace!("Got an error, going to errored state");
state.has_errored = true;
err
},
)? {
Poll::Ready(ct) => ct,
// implicit reading -> reading or implicit paused -> paused
Poll::Pending => return Poll::Pending,
};
if bytect == 0 {
if state.eof {
// We're already at an EOF, and since we've reached this path
// we're also not readable. This implies that we've already finished
// our `decode_eof` handling, so we can simply return `None`.
// implicit paused -> paused
return Poll::Ready(None);
}
// prepare reading -> paused
state.eof = true;
} else {
// prepare paused -> framing or noop reading -> framing
state.eof = false;
}
// paused -> framing or reading -> framing or reading -> pausing
state.is_readable = true;
}
}
}
Wow! All it does is read some bytes, decode the frames, and then repeat until there is an error or no more bytes.
Why is it so complicated? Your first instinct might be to blame the fact that we must use polling and not .await.
However, if you look closely at this implementation, we only poll the reader in one place. If we were to replace that with await, the implementation would look almost identical. Here is a diff:
-impl<T, U, R> Stream for FramedImpl<T, U, R>
+impl<T, U, R> AsyncIterator for FramedImpl<T, U, R>
where
- T: AsyncRead,
+ T: AsyncRead + Unpin,
U: Decoder,
R: BorrowMut<ReadFrame>,
{
type Item = Result<U::Item, U::Error>;
- fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ async fn next(&mut self) -> Option<Self::Item> {
use crate::util::poll_read_buf;
- let mut pinned = self.project();
- let state: &mut ReadFrame = pinned.state.borrow_mut();
+ let state: &mut ReadFrame = self.state.borrow_mut();
// The following loops implements a state machine with each state corresponding
// to a combination of the `is_readable` and `eof` flags. States persist across
// loop entries and most state transitions occur with a return.
//
// The initial state is `reading`.
//
// | state | eof | is_readable | has_errored |
// |---------|-------|-------------|-------------|
// | reading | false | false | false |
// | framing | false | true | false |
// | pausing | true | true | false |
// | paused | true | false | false |
// | errored | <any> | <any> | true |
// `decode_eof` returns Err
// ┌────────────────────────────────────────────────────────┐
// `decode_eof` returns │ │
// `Ok(Some)` │ │
// ┌─────┐ │ `decode_eof` returns After returning │
// Read 0 bytes ├─────▼──┴┐ `Ok(None)` ┌────────┐ ◄───┐ `None` ┌───▼─────┐
// ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐ └───────────┤ Errored │
// │ └─────────┘ └─┬──▲───┘ │ └───▲───▲─┘
// Pending read │ │ │ │ │ │
// ┌──────┐ │ `decode` returns `Some` │ └─────┘ │ │
// │ │ │ ┌──────┐ │ Pending │ │
// │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐ read n>0 bytes │ read │ │
// └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘ │ │
// └──┬─▲────┘ └─────┬──┬┘ │ │
// │ │ │ │ `decode` returns Err │ │
// │ └───decode` returns `None`──┘ └───────────────────────────────────────────────────────┘ │
// │ read returns Err │
// └────────────────────────────────────────────────────────────────────────────────────────────┘
loop {
// Return `None` if we have encountered an error from the underlying decoder
// See: https://github.com/tokio-rs/tokio/issues/3976
if state.has_errored {
// preparing has_errored -> paused
trace!("Returning None and setting paused");
state.is_readable = false;
state.has_errored = false;
- return Poll::Ready(None);
+ return None;
}
// Repeatedly call `decode` or `decode_eof` while the buffer is "readable",
// i.e. it _might_ contain data consumable as a frame or closing frame.
// Both signal that there is no such data by returning `None`.
//
// If `decode` couldn't read a frame and the upstream source has returned eof,
// `decode_eof` will attempt to decode the remaining bytes as closing frames.
//
// If the underlying AsyncRead is resumable, we may continue after an EOF,
// but must finish emitting all of it's associated `decode_eof` frames.
// Furthermore, we don't want to emit any `decode_eof` frames on retried
// reads after an EOF unless we've actually read more data.
if state.is_readable {
// pausing or framing
if state.eof {
// pausing
- let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| {
+ let frame = self.codec.decode_eof(&mut state.buffer).map_err(|err| {
trace!("Got an error, going to errored state");
state.has_errored = true;
err
})?;
if frame.is_none() {
state.is_readable = false; // prepare pausing -> paused
}
// implicit pausing -> pausing or pausing -> paused
- return Poll::Ready(frame.map(Ok));
+ return frame.map(Ok);
}
// framing
trace!("attempting to decode a frame");
- if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| {
+ if let Some(frame) = self.codec.decode(&mut state.buffer).map_err(|op| {
trace!("Got an error, going to errored state");
state.has_errored = true;
op
})? {
trace!("frame decoded from buffer");
// implicit framing -> framing
- return Poll::Ready(Some(Ok(frame)));
+ return Some(Ok(frame));
}
// framing -> reading
state.is_readable = false;
}
// reading or paused
// If we can't build a frame yet, try to read more data and try again.
// Make sure we've got room for at least one byte to read to ensure
// that we don't get a spurious 0 that looks like EOF.
state.buffer.reserve(1);
- let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err(
+ let bytect = self.inner.read_buf(&mut state.buffer).await.map_err(
|err| {
trace!("Got an error, going to errored state");
state.has_errored = true;
err
},
- )? {
- Poll::Ready(ct) => ct,
- // implicit reading -> reading or implicit paused -> paused
- Poll::Pending => return Poll::Pending,
- };
+ )?;
if bytect == 0 {
if state.eof {
// We're already at an EOF, and since we've reached this path
// we're also not readable. This implies that we've already finished
// our `decode_eof` handling, so we can simply return `None`.
// implicit paused -> paused
- return Poll::Ready(None);
+ return None;
}
// prepare reading -> paused
state.eof = true;
} else {
// prepare paused -> framing or noop reading -> framing
state.eof = false;
}
// paused -> framing or reading -> framing or reading -> pausing
state.is_readable = true;
}
}
}
Is that really all we can do? Unfortunately, yes. Reading bytes and decoding them is complex.
Sometimes we must repeat calls to read to get a single frame. Sometimes a single call to read returns enough bytes for multiple frames.
Because of this, the state machine construction is still necessary.
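The mismatch between reads and frames is easy to see even without any async code. Below is a small synchronous sketch, assuming a made-up newline-delimited decoder; decode and frames_per_read are hypothetical helpers for this illustration and not the tokio-util API.

```rust
// Hypothetical newline-delimited decoder: removes one complete line from
// `buf` and returns it, or returns None if no full line is buffered yet.
fn decode(buf: &mut Vec<u8>) -> Option<String> {
    let pos = buf.iter().position(|&b| b == b'\n')?;
    // Take the line including its trailing newline out of the buffer.
    let line: Vec<u8> = buf.drain(..=pos).collect();
    Some(String::from_utf8_lossy(&line[..pos]).into_owned())
}

// Feeds pre-chunked "reads" through the decoder and records which frames
// became available after each individual read.
fn frames_per_read(reads: &[&[u8]]) -> Vec<Vec<String>> {
    let mut buf = Vec::new();
    let mut out = Vec::new();
    for chunk in reads {
        buf.extend_from_slice(chunk);
        let mut frames = Vec::new();
        // One read may complete zero, one, or many frames.
        while let Some(frame) = decode(&mut buf) {
            frames.push(frame);
        }
        out.push(frames);
    }
    out
}
```

Feeding it the three reads "he", "llo\nwor" and "ld\nfoo\nbar\n" gives no frame after the first read, one after the second, and three after the third, so a single call to next() cannot map one-to-one onto a single read.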
We've also introduced a regression by needing to have T: Unpin. The proposed solution here is to use a PinnedAsyncIterator trait, but I don't think that really makes anything simpler.
Instead of worrying about what the AsyncIterator trait looks like, we should provide a convenient syntax to implement it.
Here's how I imagine the implementation above could look when written as a generator.
#[yield(U::Item)]
async fn framed_impl<T, U>(t: T, mut codec: U) -> Result<(), U::Error>
where
    T: AsyncRead,
    U: Decoder,
{
    let mut inner = std::pin::pin!(t);
    let mut buffer = BytesMut::with_capacity(INITIAL_CAPACITY);
    loop {
        buffer.reserve(1);
        // read_buf appends to the buffer; reading 0 bytes means EOF
        let eof = inner.read_buf(&mut buffer).await? == 0;
        'framing: loop {
            trace!("attempting to decode a frame");
            let frame = if eof {
                codec.decode_eof(&mut buffer)?
            } else {
                codec.decode(&mut buffer)?
            };
            match frame {
                Some(frame) => {
                    trace!("frame decoded from buffer");
                    yield Ok(frame);
                    continue 'framing;
                }
                // nothing left to read. goodbye
                None if eof => return Ok(()),
                // go back to reading
                None => break 'framing,
            };
        }
    }
}
Do I need to say more?