use core::cell::{RefCell, UnsafeCell};
use core::convert::Infallible;
use core::future::Future;
use core::ops::Range;
use core::pin::Pin;
use core::task::{Context, Poll};
use crate::blocking_mutex::raw::RawMutex;
use crate::blocking_mutex::Mutex;
use crate::ring_buffer::RingBuffer;
use crate::waitqueue::WakerRegistration;
pub struct Writer<'p, M, const N: usize>
where
M: RawMutex,
{
pipe: &'p Pipe<M, N>,
}
impl<'p, M, const N: usize> Clone for Writer<'p, M, N>
where
M: RawMutex,
{
fn clone(&self) -> Self {
Writer { pipe: self.pipe }
}
}
impl<'p, M, const N: usize> Copy for Writer<'p, M, N> where M: RawMutex {}
impl<'p, M, const N: usize> Writer<'p, M, N>
where
M: RawMutex,
{
pub fn write<'a>(&'a self, buf: &'a [u8]) -> WriteFuture<'a, M, N> {
self.pipe.write(buf)
}
pub fn try_write(&self, buf: &[u8]) -> Result<usize, TryWriteError> {
self.pipe.try_write(buf)
}
}
/// Future returned by [`Pipe::write`] and [`Writer::write`].
///
/// Resolves to the number of bytes copied from `buf` into the pipe, which may be less than
/// `buf.len()`. While the pipe has no free space it registers the task's waker and stays
/// pending. An empty `buf` resolves to `0` immediately.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct WriteFuture<'p, M, const N: usize>
where
    M: RawMutex,
{
    pipe: &'p Pipe<M, N>,
    buf: &'p [u8],
}

impl<'p, M, const N: usize> Future for WriteFuture<'p, M, N>
where
    M: RawMutex,
{
    type Output = usize;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Writing nothing must not wait. Otherwise a full pipe would park this future until a
        // reader frees space, only to then report 0 bytes written. It also matches the
        // `embedded_io_async::Write` contract, which the `Write` impls forward to this future:
        // `write(&[])` returns `Ok(0)` without waiting.
        if self.buf.is_empty() {
            return Poll::Ready(0);
        }
        match self.pipe.try_write_with_context(Some(cx), self.buf) {
            Ok(n) => Poll::Ready(n),
            Err(TryWriteError::Full) => Poll::Pending,
        }
    }
}

impl<'p, M, const N: usize> Unpin for WriteFuture<'p, M, N> where M: RawMutex {}
pub struct Reader<'p, M, const N: usize>
where
M: RawMutex,
{
pipe: &'p Pipe<M, N>,
}
impl<'p, M, const N: usize> Reader<'p, M, N>
where
M: RawMutex,
{
pub fn read<'a>(&'a self, buf: &'a mut [u8]) -> ReadFuture<'a, M, N> {
self.pipe.read(buf)
}
pub fn try_read(&self, buf: &mut [u8]) -> Result<usize, TryReadError> {
self.pipe.try_read(buf)
}
pub fn fill_buf(&mut self) -> FillBufFuture<'_, M, N> {
FillBufFuture { pipe: Some(self.pipe) }
}
pub fn try_fill_buf(&mut self) -> Result<&[u8], TryReadError> {
unsafe { self.pipe.try_fill_buf_with_context(None) }
}
pub fn consume(&mut self, amt: usize) {
self.pipe.consume(amt)
}
}
/// Future returned by [`Pipe::read`] and [`Reader::read`].
///
/// Resolves to the number of bytes copied out of the pipe into `buf`. While the pipe is empty
/// it registers the task's waker and stays pending. An empty `buf` resolves to `0`
/// immediately.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadFuture<'p, M, const N: usize>
where
    M: RawMutex,
{
    pipe: &'p Pipe<M, N>,
    buf: &'p mut [u8],
}

impl<'p, M, const N: usize> Future for ReadFuture<'p, M, N>
where
    M: RawMutex,
{
    type Output = usize;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Reading into an empty buffer must not wait. Otherwise an empty pipe would park this
        // future until a writer shows up, only to then report 0 bytes read. It also matches the
        // `embedded_io_async::Read` contract, which the `Read` impls forward to this future:
        // `read(&mut [])` returns `Ok(0)` without waiting.
        if self.buf.is_empty() {
            return Poll::Ready(0);
        }
        match self.pipe.try_read_with_context(Some(cx), self.buf) {
            Ok(n) => Poll::Ready(n),
            Err(TryReadError::Empty) => Poll::Pending,
        }
    }
}

impl<'p, M, const N: usize> Unpin for ReadFuture<'p, M, N> where M: RawMutex {}
/// Future returned by [`Reader::fill_buf`].
///
/// Resolves to a slice of the currently readable bytes. While the pipe is empty it registers
/// the task's waker and stays pending.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct FillBufFuture<'p, M, const N: usize>
where
    M: RawMutex,
{
    // `None` only after the future has resolved; taken on every poll and put back when pending.
    pipe: Option<&'p Pipe<M, N>>,
}

impl<'p, M, const N: usize> Future for FillBufFuture<'p, M, N>
where
    M: RawMutex,
{
    type Output = &'p [u8];

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let pipe = self.pipe.take().unwrap();

        // SAFETY: this future is only built by `Reader::fill_buf`, which ties `'p` to a mutable
        // borrow of the reader. While the returned slice lives, `consume` cannot be called, so
        // its bytes are not released to the writer.
        let outcome = unsafe { pipe.try_fill_buf_with_context(Some(cx)) };
        if let Ok(data) = outcome {
            return Poll::Ready(data);
        }

        // Nothing readable yet; the waker was registered, so keep the pipe for the next poll.
        self.pipe = Some(pipe);
        Poll::Pending
    }
}

impl<'p, M, const N: usize> Unpin for FillBufFuture<'p, M, N> where M: RawMutex {}
/// Error returned by the non-waiting read operations ([`Pipe::try_read`], [`Reader::try_read`],
/// [`Reader::try_fill_buf`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum TryReadError {
    /// The pipe currently holds no readable bytes.
    Empty,
}
/// Error returned by the non-waiting write operations ([`Pipe::try_write`],
/// [`Writer::try_write`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum TryWriteError {
    /// The pipe currently has no free space.
    Full,
}
/// Mutable pipe bookkeeping, kept behind the pipe's mutex.
struct PipeState<const N: usize> {
    /// Task to wake when bytes become readable.
    read_waker: WakerRegistration,
    /// Task to wake when space becomes writable.
    write_waker: WakerRegistration,
    /// Tracks which part of the byte storage is readable and which is free.
    buffer: RingBuffer<N>,
}
/// Raw byte storage of a [`Pipe`].
///
/// The bytes live in an `UnsafeCell`, so slices of disjoint regions can be handed out through a
/// shared `&Buffer`. The pipe's ring-buffer bookkeeping decides which regions are safe to touch.
#[repr(transparent)]
struct Buffer<const N: usize>(UnsafeCell<[u8; N]>);

impl<const N: usize> Buffer<N> {
    /// Returns a shared slice over the bytes at `r.start..r.end`.
    ///
    /// # Safety
    ///
    /// `r` must satisfy `r.start <= r.end <= N`, and no mutable slice overlapping `r` may be
    /// alive during the caller-chosen lifetime `'a`.
    unsafe fn get<'a>(&self, r: Range<usize>) -> &'a [u8] {
        let base: *const u8 = self.0.get().cast::<u8>();
        let len = r.end - r.start;
        core::slice::from_raw_parts(base.add(r.start), len)
    }

    /// Returns a mutable slice over the bytes at `r.start..r.end`.
    ///
    /// # Safety
    ///
    /// `r` must satisfy `r.start <= r.end <= N`, and no other slice overlapping `r` may be
    /// alive during the caller-chosen lifetime `'a`.
    unsafe fn get_mut<'a>(&self, r: Range<usize>) -> &'a mut [u8] {
        let base: *mut u8 = self.0.get().cast::<u8>();
        let len = r.end - r.start;
        core::slice::from_raw_parts_mut(base.add(r.start), len)
    }
}

// SAFETY: all access to the cell's contents goes through the `unsafe` accessors above, whose
// callers must keep the handed-out regions non-overlapping.
unsafe impl<const N: usize> Send for Buffer<N> {}
unsafe impl<const N: usize> Sync for Buffer<N> {}
/// A fixed-capacity byte pipe holding up to `N` bytes.
///
/// The ring-buffer state and wakers are guarded by a blocking mutex of kind `M`. The bytes
/// themselves live outside the mutex, so a reader can keep a slice to them (see
/// [`Reader::fill_buf`]).
pub struct Pipe<M, const N: usize>
where
    M: RawMutex,
{
    /// Ring-buffer bookkeeping and reader/writer wakers.
    inner: Mutex<M, RefCell<PipeState<N>>>,
    /// Byte storage; which region may be accessed is decided by `inner`.
    buf: Buffer<N>,
}
impl<M, const N: usize> Pipe<M, N>
where
M: RawMutex,
{
pub const fn new() -> Self {
Self {
buf: Buffer(UnsafeCell::new([0; N])),
inner: Mutex::new(RefCell::new(PipeState {
buffer: RingBuffer::new(),
read_waker: WakerRegistration::new(),
write_waker: WakerRegistration::new(),
})),
}
}
fn lock<R>(&self, f: impl FnOnce(&mut PipeState<N>) -> R) -> R {
self.inner.lock(|rc| f(&mut *rc.borrow_mut()))
}
fn try_read_with_context(&self, cx: Option<&mut Context<'_>>, buf: &mut [u8]) -> Result<usize, TryReadError> {
self.inner.lock(|rc: &RefCell<PipeState<N>>| {
let s = &mut *rc.borrow_mut();
if s.buffer.is_full() {
s.write_waker.wake();
}
let available = unsafe { self.buf.get(s.buffer.pop_buf()) };
if available.is_empty() {
if let Some(cx) = cx {
s.read_waker.register(cx.waker());
}
return Err(TryReadError::Empty);
}
let n = available.len().min(buf.len());
buf[..n].copy_from_slice(&available[..n]);
s.buffer.pop(n);
Ok(n)
})
}
unsafe fn try_fill_buf_with_context(&self, cx: Option<&mut Context<'_>>) -> Result<&[u8], TryReadError> {
self.inner.lock(|rc: &RefCell<PipeState<N>>| {
let s = &mut *rc.borrow_mut();
if s.buffer.is_full() {
s.write_waker.wake();
}
let available = unsafe { self.buf.get(s.buffer.pop_buf()) };
if available.is_empty() {
if let Some(cx) = cx {
s.read_waker.register(cx.waker());
}
return Err(TryReadError::Empty);
}
Ok(available)
})
}
fn consume(&self, amt: usize) {
self.inner.lock(|rc: &RefCell<PipeState<N>>| {
let s = &mut *rc.borrow_mut();
let available = s.buffer.pop_buf();
assert!(amt <= available.len());
s.buffer.pop(amt);
})
}
fn try_write_with_context(&self, cx: Option<&mut Context<'_>>, buf: &[u8]) -> Result<usize, TryWriteError> {
self.inner.lock(|rc: &RefCell<PipeState<N>>| {
let s = &mut *rc.borrow_mut();
if s.buffer.is_empty() {
s.read_waker.wake();
}
let available = unsafe { self.buf.get_mut(s.buffer.push_buf()) };
if available.is_empty() {
if let Some(cx) = cx {
s.write_waker.register(cx.waker());
}
return Err(TryWriteError::Full);
}
let n = available.len().min(buf.len());
available[..n].copy_from_slice(&buf[..n]);
s.buffer.push(n);
Ok(n)
})
}
pub fn split(&mut self) -> (Reader<'_, M, N>, Writer<'_, M, N>) {
(Reader { pipe: self }, Writer { pipe: self })
}
pub fn write<'a>(&'a self, buf: &'a [u8]) -> WriteFuture<'a, M, N> {
WriteFuture { pipe: self, buf }
}
pub async fn write_all(&self, mut buf: &[u8]) {
while !buf.is_empty() {
let n = self.write(buf).await;
buf = &buf[n..];
}
}
pub fn try_write(&self, buf: &[u8]) -> Result<usize, TryWriteError> {
self.try_write_with_context(None, buf)
}
pub fn read<'a>(&'a self, buf: &'a mut [u8]) -> ReadFuture<'a, M, N> {
ReadFuture { pipe: self, buf }
}
pub fn try_read(&self, buf: &mut [u8]) -> Result<usize, TryReadError> {
self.try_read_with_context(None, buf)
}
pub fn clear(&self) {
self.inner.lock(|rc: &RefCell<PipeState<N>>| {
let s = &mut *rc.borrow_mut();
s.buffer.clear();
s.write_waker.wake();
})
}
pub fn is_full(&self) -> bool {
self.len() == N
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn capacity(&self) -> usize {
N
}
pub fn len(&self) -> usize {
self.lock(|c| c.buffer.len())
}
pub fn free_capacity(&self) -> usize {
N - self.len()
}
}
// `Pipe` can be used directly as an `embedded_io_async` byte stream. Its operations cannot
// fail, hence the `Infallible` error type.
impl<M: RawMutex, const N: usize> embedded_io_async::ErrorType for Pipe<M, N> {
    type Error = Infallible;
}

impl<M: RawMutex, const N: usize> embedded_io_async::Read for Pipe<M, N> {
    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
        let n = Pipe::read(&*self, buf).await;
        Ok(n)
    }
}

impl<M: RawMutex, const N: usize> embedded_io_async::Write for Pipe<M, N> {
    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
        let n = Pipe::write(&*self, buf).await;
        Ok(n)
    }

    // Written bytes go straight into the shared ring buffer; there is nothing to flush.
    async fn flush(&mut self) -> Result<(), Self::Error> {
        Ok(())
    }
}
// A shared `&Pipe` is also an `embedded_io_async` stream, so several tasks can read or write
// the same pipe through plain references.
impl<M: RawMutex, const N: usize> embedded_io_async::ErrorType for &Pipe<M, N> {
    type Error = Infallible;
}

impl<M: RawMutex, const N: usize> embedded_io_async::Read for &Pipe<M, N> {
    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
        let n = Pipe::read(*self, buf).await;
        Ok(n)
    }
}

impl<M: RawMutex, const N: usize> embedded_io_async::Write for &Pipe<M, N> {
    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
        let n = Pipe::write(*self, buf).await;
        Ok(n)
    }

    // Written bytes go straight into the shared ring buffer; there is nothing to flush.
    async fn flush(&mut self) -> Result<(), Self::Error> {
        Ok(())
    }
}
// The read half implements both plain and buffered `embedded_io_async` reading.
impl<M: RawMutex, const N: usize> embedded_io_async::ErrorType for Reader<'_, M, N> {
    type Error = Infallible;
}

impl<M: RawMutex, const N: usize> embedded_io_async::Read for Reader<'_, M, N> {
    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
        let n = Reader::read(&*self, buf).await;
        Ok(n)
    }
}

impl<M: RawMutex, const N: usize> embedded_io_async::BufRead for Reader<'_, M, N> {
    async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
        let data = Reader::fill_buf(self).await;
        Ok(data)
    }

    fn consume(&mut self, amt: usize) {
        // Resolves to the inherent `Reader::consume`, not back to this trait method.
        Reader::consume(self, amt)
    }
}
// The write half is an `embedded_io_async` writer.
impl<M: RawMutex, const N: usize> embedded_io_async::ErrorType for Writer<'_, M, N> {
    type Error = Infallible;
}

impl<M: RawMutex, const N: usize> embedded_io_async::Write for Writer<'_, M, N> {
    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
        let n = Writer::write(&*self, buf).await;
        Ok(n)
    }

    // Written bytes go straight into the shared ring buffer; there is nothing to flush.
    async fn flush(&mut self) -> Result<(), Self::Error> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use futures_executor::ThreadPool;
    use futures_util::task::SpawnExt;
    use static_cell::StaticCell;

    use super::*;
    use crate::blocking_mutex::raw::{CriticalSectionRawMutex, NoopRawMutex};

    // One single-byte write leaves N - 1 bytes free.
    #[test]
    fn writing_once() {
        let pipe = Pipe::<NoopRawMutex, 3>::new();
        assert!(pipe.try_write(&[1]).is_ok());
        assert_eq!(pipe.free_capacity(), 2);
    }

    // After N single-byte writes, further non-waiting writes report `Full`.
    #[test]
    fn writing_when_full() {
        let pipe = Pipe::<NoopRawMutex, 3>::new();
        for &byte in [42u8, 43, 44].iter() {
            assert_eq!(pipe.try_write(&[byte]), Ok(1));
        }
        assert_eq!(pipe.try_write(&[45]), Err(TryWriteError::Full));
        assert_eq!(pipe.free_capacity(), 0);
    }

    // A read drains what was written and frees the space again.
    #[test]
    fn receiving_once_with_one_send() {
        let pipe = Pipe::<NoopRawMutex, 3>::new();
        assert!(pipe.try_write(&[42]).is_ok());

        let mut out = [0; 16];
        assert_eq!(pipe.try_read(&mut out), Ok(1));
        assert_eq!(out[0], 42);
        assert_eq!(pipe.free_capacity(), 3);
    }

    // Reading an empty pipe reports `Empty` and changes nothing.
    #[test]
    fn receiving_when_empty() {
        let pipe = Pipe::<NoopRawMutex, 3>::new();
        let mut out = [0; 16];
        assert_eq!(pipe.try_read(&mut out), Err(TryReadError::Empty));
        assert_eq!(pipe.free_capacity(), 3);
    }

    #[test]
    fn simple_send_and_receive() {
        let pipe = Pipe::<NoopRawMutex, 3>::new();
        assert!(pipe.try_write(&[42]).is_ok());

        let mut out = [0; 16];
        assert_eq!(pipe.try_read(&mut out), Ok(1));
        assert_eq!(out[0], 42);
    }

    // Exercises fill_buf/consume, including ring-buffer wrap-around.
    #[test]
    fn read_buf() {
        let mut pipe = Pipe::<NoopRawMutex, 3>::new();
        let (mut reader, writer) = pipe.split();
        assert!(writer.try_write(&[42, 43]).is_ok());

        // Peeking twice returns the same bytes; nothing is consumed yet.
        for _ in 0..2 {
            let peeked = reader.try_fill_buf().unwrap();
            assert_eq!(peeked, &[42, 43]);
        }

        reader.consume(1);
        let peeked = reader.try_fill_buf().unwrap();
        assert_eq!(peeked, &[43]);
        reader.consume(1);
        assert_eq!(reader.try_fill_buf(), Err(TryReadError::Empty));

        // The write position is now at the end of the storage, so writes wrap around.
        assert_eq!(writer.try_write(&[44, 45, 46]), Ok(1));
        assert_eq!(writer.try_write(&[45, 46]), Ok(2));

        // Only the contiguous region before the wrap is visible.
        let peeked = reader.try_fill_buf().unwrap();
        assert_eq!(peeked, &[44]);
        reader.consume(1);

        let peeked = reader.try_fill_buf().unwrap();
        assert_eq!(peeked, &[45, 46]);
        assert!(writer.try_write(&[47]).is_ok());

        let peeked = reader.try_fill_buf().unwrap();
        assert_eq!(peeked, &[45, 46, 47]);
        reader.consume(3);
    }

    #[test]
    fn writer_is_cloneable() {
        let mut pipe = Pipe::<NoopRawMutex, 3>::new();
        let (_reader, writer) = pipe.split();
        let _ = writer.clone();
    }

    // A pending read is woken by a write from another task.
    #[futures_test::test]
    async fn receiver_receives_given_try_write_async() {
        let executor = ThreadPool::new().unwrap();

        static CHANNEL: StaticCell<Pipe<CriticalSectionRawMutex, 3>> = StaticCell::new();
        let pipe = &*CHANNEL.init(Pipe::new());
        let producer_pipe = pipe;
        let producer = async move {
            assert_eq!(producer_pipe.try_write(&[42]), Ok(1));
        };
        executor.spawn(producer).unwrap();

        let mut out = [0; 16];
        assert_eq!(pipe.read(&mut out).await, 1);
        assert_eq!(out[0], 42);
    }

    // A write into a pipe with room completes without another task.
    #[futures_test::test]
    async fn sender_send_completes_if_capacity() {
        let pipe = Pipe::<CriticalSectionRawMutex, 1>::new();
        pipe.write(&[42]).await;

        let mut out = [0; 16];
        assert_eq!(pipe.read(&mut out).await, 1);
        assert_eq!(out[0], 42);
    }
}