use crate::{iter::IterableByOverlaps, ReadStorage, Region, Storage};
/// An error produced by a NOR flash implementation.
///
/// Every such error must be `Debug` and must be classifiable as a generic
/// [`NorFlashErrorKind`].
pub trait NorFlashError: core::fmt::Debug {
/// Classify this error as one of the generic [`NorFlashErrorKind`] variants.
fn kind(&self) -> NorFlashErrorKind;
}
/// `Infallible` can be used as the error type of flash implementations that never fail.
impl NorFlashError for core::convert::Infallible {
    /// Unreachable: `Infallible` has no values, which the empty `match` proves to the compiler.
    fn kind(&self) -> NorFlashErrorKind {
        let never: core::convert::Infallible = *self;
        match never {}
    }
}
/// Associates an error type with a NOR flash implementation.
///
/// The error type must implement [`NorFlashError`], so it can always be mapped to a
/// [`NorFlashErrorKind`].
pub trait ErrorType {
/// Error returned by the fallible methods of [`ReadNorFlash`] and [`NorFlash`].
type Error: NorFlashError;
}
/// Implementation-independent classification of NOR flash errors.
///
/// The enum is `#[non_exhaustive]`, so `match`es outside this crate need a wildcard arm.
/// NOTE: variant order is significant because it determines the derived `Ord`/`PartialOrd`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[non_exhaustive]
pub enum NorFlashErrorKind {
/// An offset, length or bound is not a multiple of the required access size
/// (returned by `check_read`, `check_write` and `check_erase`).
NotAligned,
/// The requested range does not fit within the flash capacity, or an erase range has
/// `from > to` (returned by the same `check_*` helpers).
OutOfBounds,
/// An implementation-specific error occurred.
Other,
}
impl NorFlashError for NorFlashErrorKind {
    /// A kind is already classified, so it reports itself.
    fn kind(&self) -> NorFlashErrorKind {
        *self
    }
}

impl core::fmt::Display for NorFlashErrorKind {
    /// Writes a short, human-readable description of the error kind.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let description = match self {
            Self::NotAligned => "Arguments are not properly aligned",
            Self::OutOfBounds => "Arguments are out of bounds",
            Self::Other => "An implementation specific error occurred",
        };
        f.write_str(description)
    }
}
/// Read access to a NOR flash device.
pub trait ReadNorFlash: ErrorType {
/// Alignment, in bytes, that [`check_read`] requires of read offsets and lengths.
const READ_SIZE: usize;
/// Read from the flash, starting at address `offset`, into `bytes`.
///
/// NOTE(review): presumably fills all of `bytes`. The storage wrappers in this module rely
/// on that when they read a whole page into their merge buffer — confirm against drivers.
fn read(&mut self, offset: u32, bytes: &mut [u8]) -> Result<(), Self::Error>;
/// Total flash size in bytes; the upper bound used by the `check_*` helpers.
fn capacity(&self) -> usize;
}
/// Validate a read of `length` bytes starting at `offset` against `flash`.
///
/// # Errors
/// - [`NorFlashErrorKind::OutOfBounds`] if `offset..offset + length` does not fit in
///   `flash.capacity()`.
/// - [`NorFlashErrorKind::NotAligned`] if `offset` or `length` is not a multiple of
///   `T::READ_SIZE`.
///
/// Bounds are checked before alignment.
pub fn check_read<T: ReadNorFlash>(
    flash: &T,
    offset: u32,
    length: usize,
) -> Result<(), NorFlashErrorKind> {
    let read_granularity = <T as ReadNorFlash>::READ_SIZE;
    check_slice(flash, read_granularity, offset, length)
}
/// Erase and write access to a NOR flash device.
pub trait NorFlash: ReadNorFlash {
/// Alignment, in bytes, that [`check_write`] requires of write offsets and lengths.
const WRITE_SIZE: usize;
/// Erase granularity in bytes. [`check_erase`] requires both erase bounds to be multiples
/// of it, and the read-modify-write storage wrappers in this module use it as their page size.
const ERASE_SIZE: usize;
/// Erase the address range from `from` to `to`.
///
/// NOTE(review): `to` appears to be exclusive, because this module erases a page with
/// `erase(page.start, page.end())` — confirm against drivers.
fn erase(&mut self, from: u32, to: u32) -> Result<(), Self::Error>;
/// Write `bytes` to the flash starting at address `offset`.
fn write(&mut self, offset: u32, bytes: &[u8]) -> Result<(), Self::Error>;
}
/// Validate an erase of the range from `from` to `to` against `flash`.
///
/// # Errors
/// - [`NorFlashErrorKind::OutOfBounds`] if `from > to` or `to` exceeds `flash.capacity()`.
/// - [`NorFlashErrorKind::NotAligned`] if `from` or `to` is not a multiple of `T::ERASE_SIZE`.
///
/// Bounds are checked before alignment; an empty range (`from == to`) is accepted.
pub fn check_erase<T: NorFlash>(flash: &T, from: u32, to: u32) -> Result<(), NorFlashErrorKind> {
    let start = from as usize;
    let end = to as usize;

    // Short-circuits exactly like the negated form: capacity is only queried when start <= end.
    let within_capacity = start <= end && end <= flash.capacity();
    if !within_capacity {
        return Err(NorFlashErrorKind::OutOfBounds);
    }

    let granularity = T::ERASE_SIZE;
    if start % granularity == 0 && end % granularity == 0 {
        Ok(())
    } else {
        Err(NorFlashErrorKind::NotAligned)
    }
}
/// Validate a write of `length` bytes starting at `offset` against `flash`.
///
/// # Errors
/// - [`NorFlashErrorKind::OutOfBounds`] if `offset..offset + length` does not fit in
///   `flash.capacity()`.
/// - [`NorFlashErrorKind::NotAligned`] if `offset` or `length` is not a multiple of
///   `T::WRITE_SIZE`.
///
/// Bounds are checked before alignment.
pub fn check_write<T: NorFlash>(
    flash: &T,
    offset: u32,
    length: usize,
) -> Result<(), NorFlashErrorKind> {
    let write_granularity = <T as NorFlash>::WRITE_SIZE;
    check_slice(flash, write_granularity, offset, length)
}
/// Shared bounds and alignment check behind [`check_read`] and [`check_write`].
///
/// Returns `OutOfBounds` when `offset..offset + length` does not fit in `flash.capacity()`.
/// The check is done without integer overflow. Otherwise returns `NotAligned` when `offset`
/// or `length` is not a multiple of `align`.
fn check_slice<T: ReadNorFlash>(
    flash: &T,
    align: usize,
    offset: u32,
    length: usize,
) -> Result<(), NorFlashErrorKind> {
    let start = offset as usize;

    // `capacity - length` is the last start address that still leaves room for `length` bytes.
    // `None` means `length` alone already exceeds the capacity.
    let fits = match flash.capacity().checked_sub(length) {
        Some(last_valid_start) => start <= last_valid_start,
        None => false,
    };
    if !fits {
        return Err(NorFlashErrorKind::OutOfBounds);
    }

    if start % align == 0 && length % align == 0 {
        Ok(())
    } else {
        Err(NorFlashErrorKind::NotAligned)
    }
}
/// A mutable reference to a flash implementation uses the same error type as the
/// implementation itself. This pairs with the forwarding `ReadNorFlash`/`NorFlash` impls
/// for `&mut T`.
impl<T: ErrorType> ErrorType for &mut T {
type Error = T::Error;
}
/// Lets `&mut T` be used wherever a [`ReadNorFlash`] is expected by forwarding every
/// item to the referenced `T`.
impl<T: ReadNorFlash> ReadNorFlash for &mut T {
    const READ_SIZE: usize = <T as ReadNorFlash>::READ_SIZE;

    fn read(&mut self, offset: u32, bytes: &mut [u8]) -> Result<(), Self::Error> {
        // Fully qualified so the call can only resolve to the inner `T`'s implementation.
        <T as ReadNorFlash>::read(&mut **self, offset, bytes)
    }

    fn capacity(&self) -> usize {
        <T as ReadNorFlash>::capacity(&**self)
    }
}
/// Lets `&mut T` be used wherever a [`NorFlash`] is expected by forwarding every item
/// to the referenced `T`.
impl<T: NorFlash> NorFlash for &mut T {
    const WRITE_SIZE: usize = <T as NorFlash>::WRITE_SIZE;
    const ERASE_SIZE: usize = <T as NorFlash>::ERASE_SIZE;

    fn erase(&mut self, from: u32, to: u32) -> Result<(), Self::Error> {
        <T as NorFlash>::erase(&mut **self, from, to)
    }

    fn write(&mut self, offset: u32, bytes: &[u8]) -> Result<(), Self::Error> {
        <T as NorFlash>::write(&mut **self, offset, bytes)
    }
}
/// Marker trait for [`NorFlash`] devices whose locations may be written again without an
/// intervening erase.
///
/// [`RmwMultiwriteNorFlashStorage`] requires this bound. It uses it to write in place,
/// skipping the erase, when the new bytes only clear bits relative to the stored bytes.
pub trait MultiwriteNorFlash: NorFlash {}
/// One erase-sized page of the flash address space. The storage wrappers split writes
/// into per-page chunks using these.
struct Page {
    /// Address of the first byte of the page.
    pub start: u32,
    /// Length of the page in bytes.
    pub size: usize,
}

impl Page {
    /// Build the page with zero-based `index`, where every page is `size` bytes long.
    fn new(index: u32, size: usize) -> Self {
        let page_len = size as u32;
        Self {
            start: index * page_len,
            size,
        }
    }

    /// Address one past the last byte of the page (exclusive end).
    const fn end(&self) -> u32 {
        let page_len = self.size as u32;
        self.start + page_len
    }
}
impl Region for Page {
    /// `true` when `address` lies in the half-open range `start..end()`.
    fn contains(&self, address: u32) -> bool {
        // Reject addresses below the page before computing `end()`.
        if address < self.start {
            return false;
        }
        address < self.end()
    }
}
/// A [`Storage`] adapter over a [`NorFlash`] that performs read-modify-write.
///
/// Every page touched by a write is read into `merge_buffer`, erased, patched with the
/// new bytes, and written back in full.
pub struct RmwNorFlashStorage<'a, S> {
    storage: S,
    merge_buffer: &'a mut [u8],
}

impl<'a, S> RmwNorFlashStorage<'a, S>
where
    S: NorFlash,
{
    /// Wrap `nor_flash`, using `merge_buffer` as scratch space for one erase page.
    ///
    /// # Panics
    /// Panics with "Merge buffer is too small" if `merge_buffer` is shorter than
    /// `S::ERASE_SIZE`.
    pub fn new(nor_flash: S, merge_buffer: &'a mut [u8]) -> Self {
        assert!(
            merge_buffer.len() >= S::ERASE_SIZE,
            "Merge buffer is too small"
        );
        Self {
            storage: nor_flash,
            merge_buffer,
        }
    }
}
impl<'a, S> ReadStorage for RmwNorFlashStorage<'a, S>
where
    S: ReadNorFlash,
{
    /// Errors are exactly those of the wrapped flash.
    type Error = S::Error;

    /// Reads go straight to the wrapped flash; the merge buffer is not involved.
    fn read(&mut self, offset: u32, bytes: &mut [u8]) -> Result<(), Self::Error> {
        <S as ReadNorFlash>::read(&mut self.storage, offset, bytes)
    }

    /// Capacity of the wrapped flash.
    fn capacity(&self) -> usize {
        <S as ReadNorFlash>::capacity(&self.storage)
    }
}
impl<'a, S> Storage for RmwNorFlashStorage<'a, S>
where
    S: NorFlash,
{
    /// Write `bytes` at `offset` by rewriting every erase page the range overlaps.
    ///
    /// For each overlapped page:
    /// 1. Read the whole page into the merge buffer.
    /// 2. Erase the page.
    /// 3. Copy the new bytes into the buffered copy.
    /// 4. Write the buffer back.
    ///
    /// # Errors
    /// Returns the first error reported by the wrapped flash. Pages finished before the
    /// failure keep their new contents. A failure after a page's erase can leave that page
    /// erased.
    fn write(&mut self, offset: u32, bytes: &[u8]) -> Result<(), Self::Error> {
        let page_size = S::ERASE_SIZE;
        let page_count = self.storage.capacity() / page_size;

        for (chunk, page, chunk_addr) in (0..page_count as u32)
            .map(move |index| Page::new(index, page_size))
            .overlaps(bytes, offset)
        {
            let skip = chunk_addr.saturating_sub(page.start) as usize;
            let scratch = &mut self.merge_buffer[..page_size];

            // Snapshot the current page contents before the erase wipes them.
            self.storage.read(page.start, scratch)?;
            self.storage.erase(page.start, page.end())?;

            // Overlay the new bytes onto the snapshot, `skip` bytes into the page.
            for (slot, byte) in scratch.iter_mut().skip(skip).zip(chunk) {
                *slot = *byte;
            }

            self.storage.write(page.start, scratch)?;
        }
        Ok(())
    }
}
/// A [`Storage`] adapter over a [`MultiwriteNorFlash`] that skips erasing when possible.
///
/// If a write only clears bits relative to the stored data, it is written in place.
/// Otherwise the affected page goes through a read-erase-patch-write cycle using
/// `merge_buffer`.
pub struct RmwMultiwriteNorFlashStorage<'a, S> {
    storage: S,
    merge_buffer: &'a mut [u8],
}

impl<'a, S> RmwMultiwriteNorFlashStorage<'a, S>
where
    S: MultiwriteNorFlash,
{
    /// Wrap `nor_flash`, using `merge_buffer` as scratch space for one erase page.
    ///
    /// # Panics
    /// Panics with "Merge buffer is too small" if `merge_buffer` is shorter than
    /// `S::ERASE_SIZE`.
    pub fn new(nor_flash: S, merge_buffer: &'a mut [u8]) -> Self {
        assert!(
            merge_buffer.len() >= S::ERASE_SIZE,
            "Merge buffer is too small"
        );
        Self {
            storage: nor_flash,
            merge_buffer,
        }
    }
}
impl<'a, S> ReadStorage for RmwMultiwriteNorFlashStorage<'a, S>
where
    S: ReadNorFlash,
{
    /// Errors are exactly those of the wrapped flash.
    type Error = S::Error;

    /// Reads go straight to the wrapped flash; the merge buffer is not involved.
    fn read(&mut self, offset: u32, bytes: &mut [u8]) -> Result<(), Self::Error> {
        <S as ReadNorFlash>::read(&mut self.storage, offset, bytes)
    }

    /// Capacity of the wrapped flash.
    fn capacity(&self) -> usize {
        <S as ReadNorFlash>::capacity(&self.storage)
    }
}
impl<'a, S> Storage for RmwMultiwriteNorFlashStorage<'a, S>
where
    S: MultiwriteNorFlash,
{
    /// Write `bytes` at `offset`, erasing only the pages where that cannot be avoided.
    ///
    /// Every erase page overlapped by the write is first read into the merge buffer. Then:
    ///
    /// * If the new bytes only clear bits relative to the stored ones (`new & old == new`
    ///   for every byte), they are written in place without erasing. The write is widened
    ///   to `S::WRITE_SIZE` boundaries on both sides, and the padding bytes are `0xff`.
    ///   Assuming programming can only clear bits (the premise of the subset check), the
    ///   padding leaves stored data unchanged.
    /// * Otherwise the page is erased, the new bytes are patched into the buffered copy,
    ///   and the whole page is written back.
    ///
    /// # Errors
    /// Returns the first error reported by the wrapped flash. Pages finished before the
    /// failure keep their new contents. A failure after a page's erase can leave that page
    /// erased.
    fn write(&mut self, offset: u32, bytes: &[u8]) -> Result<(), Self::Error> {
        let last_page = self.storage.capacity() / S::ERASE_SIZE;
        for (data, page, addr) in (0..last_page as u32)
            .map(move |i| Page::new(i, S::ERASE_SIZE))
            .overlaps(bytes, offset)
        {
            let offset_into_page = addr.saturating_sub(page.start) as usize;
            self.storage
                .read(page.start, &mut self.merge_buffer[..S::ERASE_SIZE])?;
            let rhs = &self.merge_buffer[offset_into_page..S::ERASE_SIZE];
            let is_subset = data.iter().zip(rhs.iter()).all(|(a, b)| *a & *b == *a);
            if is_subset {
                // Distance from the preceding WRITE_SIZE boundary to `addr`; the write
                // starts at that boundary.
                let offset = addr as usize % S::WRITE_SIZE;
                // Round the end up to the next WRITE_SIZE multiple so the write length is
                // aligned as well (e.g. offset 1 + 2 bytes with WRITE_SIZE 4 -> 4 bytes).
                let unaligned_end = offset + data.len();
                let aligned_end = match unaligned_end % S::WRITE_SIZE {
                    0 => unaligned_end,
                    rem => unaligned_end + (S::WRITE_SIZE - rem),
                };
                // The buffered page contents are no longer needed, so reuse the merge
                // buffer to build the padded, aligned chunk.
                self.merge_buffer[..aligned_end].fill(0xff);
                self.merge_buffer[offset..offset + data.len()].copy_from_slice(data);
                self.storage
                    .write(addr - offset as u32, &self.merge_buffer[..aligned_end])?;
            } else {
                self.storage.erase(page.start, page.end())?;
                self.merge_buffer[..S::ERASE_SIZE]
                    .iter_mut()
                    .skip(offset_into_page)
                    .zip(data)
                    .for_each(|(byte, input)| *byte = *input);
                self.storage
                    .write(page.start, &self.merge_buffer[..S::ERASE_SIZE])?;
            }
        }
        Ok(())
    }
}