use crate::{iter::IterableByOverlaps, ReadStorage, Region, Storage};
/// An error produced by a NOR flash driver.
///
/// Lets code that is generic over a driver inspect a driver-specific error
/// through the common [`NorFlashErrorKind`] categories.
pub trait NorFlashError: core::fmt::Debug {
	/// Returns the generic category of this error.
	fn kind(&self) -> NorFlashErrorKind;
}
/// `Infallible` has no values, so `kind` can never actually be invoked.
impl NorFlashError for core::convert::Infallible {
	fn kind(&self) -> NorFlashErrorKind {
		match *self {}
	}
}
/// Associates an error type with a NOR flash driver.
pub trait ErrorType {
	/// Error type returned by the driver's fallible operations.
	type Error: NorFlashError;
}
/// Generic categories of NOR flash errors.
///
/// Variant order is significant: it defines the derived `Ord`/`PartialOrd`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[non_exhaustive]
pub enum NorFlashErrorKind {
	/// An offset or length was not a multiple of the required alignment.
	NotAligned,
	/// The requested range does not fit within the device capacity.
	OutOfBounds,
	/// Any other, implementation-specific failure.
	Other,
}
/// A kind is trivially its own error category.
impl NorFlashError for NorFlashErrorKind {
	fn kind(&self) -> NorFlashErrorKind {
		*self
	}
}
impl core::fmt::Display for NorFlashErrorKind {
	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
		let description = match self {
			Self::NotAligned => "Arguments are not properly aligned",
			Self::OutOfBounds => "Arguments are out of bounds",
			Self::Other => "An implementation specific error occurred",
		};
		f.write_str(description)
	}
}
/// Read access to a NOR flash device.
pub trait ReadNorFlash: ErrorType {
	/// Alignment, in bytes, that [`check_read`] requires of both the read offset
	/// and the read length.
	const READ_SIZE: usize;
	/// Fills `bytes` with the flash contents starting at `offset`.
	fn read(&mut self, offset: u32, bytes: &mut [u8]) -> Result<(), Self::Error>;
	/// Total device size in bytes; the range checks in this module use it as the
	/// upper bound for `offset + length`.
	fn capacity(&self) -> usize;
}
/// Validates a read of `length` bytes at `offset` against `flash`.
///
/// # Errors
///
/// Returns [`NorFlashErrorKind::OutOfBounds`] if `offset..offset + length` does
/// not fit in `flash.capacity()`. Otherwise returns [`NorFlashErrorKind::NotAligned`]
/// if `offset` or `length` is not a multiple of `T::READ_SIZE`.
pub fn check_read<T: ReadNorFlash>(
	flash: &T,
	offset: u32,
	length: usize,
) -> Result<(), NorFlashErrorKind> {
	let alignment = <T as ReadNorFlash>::READ_SIZE;
	check_slice(flash, alignment, offset, length)
}
/// A NOR flash device that can also be erased and programmed.
pub trait NorFlash: ReadNorFlash {
	/// Alignment, in bytes, that [`check_write`] requires of both the write offset
	/// and the write length.
	const WRITE_SIZE: usize;
	/// Erase granularity in bytes. [`check_erase`] requires both ends of an erase
	/// range to be multiples of it, and the read-modify-write adapters in this
	/// module split the device into pages of this size.
	const ERASE_SIZE: usize;
	/// Erases the range delimited by `from` and `to`.
	///
	/// NOTE(review): the adapters in this module call `erase(page.start, page.end())`,
	/// where `end()` is exclusive (see `Region for Page`). That suggests `to` is
	/// exclusive; confirm this is the intended contract.
	fn erase(&mut self, from: u32, to: u32) -> Result<(), Self::Error>;
	/// Programs `bytes` into flash starting at `offset`.
	fn write(&mut self, offset: u32, bytes: &[u8]) -> Result<(), Self::Error>;
}
/// Validates an erase of the range delimited by `from` and `to` against `flash`.
///
/// # Errors
///
/// Returns [`NorFlashErrorKind::OutOfBounds`] if `from > to` or `to` exceeds
/// `flash.capacity()`. Otherwise returns [`NorFlashErrorKind::NotAligned`] if
/// either end is not a multiple of `T::ERASE_SIZE`.
///
/// # Panics
///
/// Panics (remainder by zero) if `T::ERASE_SIZE` is 0 and the range is in bounds.
pub fn check_erase<T: NorFlash>(flash: &T, from: u32, to: u32) -> Result<(), NorFlashErrorKind> {
	let start = from as usize;
	let end = to as usize;
	// Bounds are checked before alignment, so a bad range reports OutOfBounds first.
	let in_bounds = start <= end && end <= flash.capacity();
	if !in_bounds {
		return Err(NorFlashErrorKind::OutOfBounds);
	}
	let is_aligned = start % T::ERASE_SIZE == 0 && end % T::ERASE_SIZE == 0;
	if is_aligned {
		Ok(())
	} else {
		Err(NorFlashErrorKind::NotAligned)
	}
}
/// Validates a write of `length` bytes at `offset` against `flash`.
///
/// # Errors
///
/// Returns [`NorFlashErrorKind::OutOfBounds`] if `offset..offset + length` does
/// not fit in `flash.capacity()`. Otherwise returns [`NorFlashErrorKind::NotAligned`]
/// if `offset` or `length` is not a multiple of `T::WRITE_SIZE`.
pub fn check_write<T: NorFlash>(
	flash: &T,
	offset: u32,
	length: usize,
) -> Result<(), NorFlashErrorKind> {
	let alignment = <T as NorFlash>::WRITE_SIZE;
	check_slice(flash, alignment, offset, length)
}
/// Shared bounds-and-alignment check behind [`check_read`] and [`check_write`].
///
/// Bounds are tested first. `length <= capacity` is established before computing
/// `capacity - length`, so the comparison cannot underflow. Alignment of both
/// `offset` and `length` to `align` is tested second.
fn check_slice<T: ReadNorFlash>(
	flash: &T,
	align: usize,
	offset: u32,
	length: usize,
) -> Result<(), NorFlashErrorKind> {
	let start = offset as usize;
	let fits = length <= flash.capacity() && start <= flash.capacity() - length;
	if !fits {
		return Err(NorFlashErrorKind::OutOfBounds);
	}
	if start % align == 0 && length % align == 0 {
		Ok(())
	} else {
		Err(NorFlashErrorKind::NotAligned)
	}
}
impl<T: ErrorType> ErrorType for &mut T {
	type Error = T::Error;
}
impl<T: ReadNorFlash> ReadNorFlash for &mut T {
	const READ_SIZE: usize = T::READ_SIZE;
	fn read(&mut self, offset: u32, bytes: &mut [u8]) -> Result<(), Self::Error> {
		T::read(self, offset, bytes)
	}
	fn capacity(&self) -> usize {
		T::capacity(self)
	}
}
impl<T: NorFlash> NorFlash for &mut T {
	const WRITE_SIZE: usize = T::WRITE_SIZE;
	const ERASE_SIZE: usize = T::ERASE_SIZE;
	fn erase(&mut self, from: u32, to: u32) -> Result<(), Self::Error> {
		T::erase(self, from, to)
	}
	fn write(&mut self, offset: u32, bytes: &[u8]) -> Result<(), Self::Error> {
		T::write(self, offset, bytes)
	}
}
/// Marker trait for [`NorFlash`] devices whose already-programmed bytes may be
/// programmed again without an intervening erase.
///
/// `RmwMultiwriteNorFlashStorage` relies on this: when new data only clears bits
/// relative to the current contents (`new & old == new`), it writes in place
/// instead of erasing the page.
/// NOTE(review): the exact guarantee implementors must uphold is not visible
/// here; confirm against the driver documentation.
pub trait MultiwriteNorFlash: NorFlash {}
/// One erase page: the half-open address range `start..start + size`.
struct Page {
	pub start: u32,
	pub size: usize,
}
impl Page {
	/// Builds the page with the given zero-based `index` and byte `size`, so it
	/// starts at `index * size`.
	fn new(index: u32, size: usize) -> Self {
		let page_len = size as u32;
		let start = page_len * index;
		Self { start, size }
	}
	/// Returns the exclusive end address of the page (`start + size`).
	const fn end(&self) -> u32 {
		let len = self.size as u32;
		self.start + len
	}
}
/// A page contains the addresses in the half-open range `start..end()`.
impl Region for Page {
	fn contains(&self, address: u32) -> bool {
		address >= self.start && address < self.end()
	}
}
/// A `Storage` adapter over a [`NorFlash`] that accepts writes at any offset and
/// length. It does this by reading, erasing and rewriting every erase page the
/// write touches.
///
/// `merge_buffer` is scratch space that holds one erase page during a write.
pub struct RmwNorFlashStorage<'a, S> {
	storage: S,
	merge_buffer: &'a mut [u8],
}
impl<'a, S> RmwNorFlashStorage<'a, S>
where
	S: NorFlash,
{
	/// Wraps `nor_flash`, using `merge_buffer` as page-sized scratch space.
	///
	/// # Panics
	///
	/// Panics if `merge_buffer` is shorter than `S::ERASE_SIZE`.
	pub fn new(nor_flash: S, merge_buffer: &'a mut [u8]) -> Self {
		assert!(
			merge_buffer.len() >= S::ERASE_SIZE,
			"Merge buffer is too small"
		);
		Self {
			merge_buffer,
			storage: nor_flash,
		}
	}
}
impl<'a, S> ReadStorage for RmwNorFlashStorage<'a, S>
where
	S: ReadNorFlash,
{
	type Error = S::Error;
	fn read(&mut self, offset: u32, bytes: &mut [u8]) -> Result<(), Self::Error> {
		self.storage.read(offset, bytes)
	}
	fn capacity(&self) -> usize {
		self.storage.capacity()
	}
}
/// Writes by read-modify-write of every erase page the request touches.
impl<'a, S> Storage for RmwNorFlashStorage<'a, S>
where
	S: NorFlash,
{
	/// Writes `bytes` at `offset`.
	///
	/// The device is treated as `capacity() / ERASE_SIZE` whole pages. For each page
	/// that overlaps the request, the page is:
	/// - read into the merge buffer,
	/// - erased,
	/// - patched with the matching slice of `bytes`,
	/// - written back in full.
	///
	/// Between the erase and the write-back, the old page contents exist only in the
	/// merge buffer, so a failure in that window loses them.
	///
	/// NOTE(review): the pages visited are chosen by `overlaps`. Data lying outside
	/// the enumerated pages appears to be skipped without an error; confirm against
	/// `IterableByOverlaps`.
	///
	/// # Errors
	///
	/// Propagates the first error from the underlying `read`, `erase` or `write`.
	fn write(&mut self, offset: u32, bytes: &[u8]) -> Result<(), Self::Error> {
		let page_count = self.storage.capacity() / S::ERASE_SIZE;
		let pages = (0..page_count as u32).map(move |index| Page::new(index, S::ERASE_SIZE));
		for (chunk, page, chunk_addr) in pages.overlaps(bytes, offset) {
			let skip = chunk_addr.saturating_sub(page.start) as usize;
			let scratch = &mut self.merge_buffer[..S::ERASE_SIZE];
			self.storage.read(page.start, scratch)?;
			self.storage.erase(page.start, page.end())?;
			for (dst, src) in scratch.iter_mut().skip(skip).zip(chunk) {
				*dst = *src;
			}
			self.storage.write(page.start, scratch)?;
		}
		Ok(())
	}
}
/// A `Storage` adapter over a [`MultiwriteNorFlash`] that accepts writes at any
/// offset and length. It programs in place when the new data only clears bits,
/// and falls back to read-erase-rewrite of the page otherwise.
///
/// `merge_buffer` is scratch space that holds one erase page during a write.
pub struct RmwMultiwriteNorFlashStorage<'a, S> {
	storage: S,
	merge_buffer: &'a mut [u8],
}
impl<'a, S> RmwMultiwriteNorFlashStorage<'a, S>
where
	S: MultiwriteNorFlash,
{
	/// Wraps `nor_flash`, using `merge_buffer` as page-sized scratch space.
	///
	/// # Panics
	///
	/// Panics if `merge_buffer` is shorter than `S::ERASE_SIZE`.
	pub fn new(nor_flash: S, merge_buffer: &'a mut [u8]) -> Self {
		assert!(
			merge_buffer.len() >= S::ERASE_SIZE,
			"Merge buffer is too small"
		);
		Self {
			merge_buffer,
			storage: nor_flash,
		}
	}
}
impl<'a, S> ReadStorage for RmwMultiwriteNorFlashStorage<'a, S>
where
	S: ReadNorFlash,
{
	type Error = S::Error;
	fn read(&mut self, offset: u32, bytes: &mut [u8]) -> Result<(), Self::Error> {
		self.storage.read(offset, bytes)
	}
	fn capacity(&self) -> usize {
		self.storage.capacity()
	}
}
/// Writes that skip the erase when the new data only clears bits.
impl<'a, S> Storage for RmwMultiwriteNorFlashStorage<'a, S>
where
	S: MultiwriteNorFlash,
{
	/// Writes `bytes` at `offset`, erasing a page only when unavoidable.
	///
	/// For each erase page overlapped by the request, the page is read into the
	/// merge buffer, and then one of two paths is taken:
	/// - If every bit set in the new data is already set in flash
	///   (`new & old == new`), the data is programmed in place without an erase.
	///   It is padded with `0xff` out to whole `WRITE_SIZE` units.
	/// - Otherwise the page is erased, patched and written back in full.
	///
	/// # Errors
	///
	/// Propagates the first error from the underlying `read`, `erase` or `write`.
	fn write(&mut self, offset: u32, bytes: &[u8]) -> Result<(), Self::Error> {
		let last_page = self.storage.capacity() / S::ERASE_SIZE;
		for (data, page, addr) in (0..last_page as u32)
			.map(move |i| Page::new(i, S::ERASE_SIZE))
			.overlaps(bytes, offset)
		{
			let offset_into_page = addr.saturating_sub(page.start) as usize;
			self.storage
				.read(page.start, &mut self.merge_buffer[..S::ERASE_SIZE])?;
			let rhs = &self.merge_buffer[offset_into_page..S::ERASE_SIZE];
			// In-place programming is only attempted when `data` sets no bit that
			// is currently cleared in flash (programming is assumed to only clear bits).
			let is_subset = data.iter().zip(rhs.iter()).all(|(a, b)| *a & *b == *a);
			if is_subset {
				// Reuse the merge buffer as scratch space to pad `data` to whole
				// WRITE_SIZE units, so both the start (`addr - offset`) and the
				// length satisfy the alignment that `check_write` enforces.
				// `offset` is how far `addr` lies past the preceding WRITE_SIZE
				// boundary. The end is rounded UP to the next boundary. Assuming
				// ERASE_SIZE is a multiple of WRITE_SIZE, this stays inside the
				// page and therefore inside `merge_buffer` (len >= ERASE_SIZE, see `new`).
				// 0xff padding leaves the existing bits untouched under the
				// clear-only programming model assumed above.
				let offset = addr as usize % S::WRITE_SIZE;
				let data_end = offset + data.len();
				let aligned_end = (data_end + S::WRITE_SIZE - 1) / S::WRITE_SIZE * S::WRITE_SIZE;
				self.merge_buffer[..aligned_end].fill(0xff);
				self.merge_buffer[offset..data_end].copy_from_slice(data);
				self.storage
					.write(addr - offset as u32, &self.merge_buffer[..aligned_end])?;
			} else {
				// Some bit must go 0 -> 1: fall back to erase + full-page rewrite,
				// with the page's old contents merged in from the buffer.
				self.storage.erase(page.start, page.end())?;
				self.merge_buffer[..S::ERASE_SIZE]
					.iter_mut()
					.skip(offset_into_page)
					.zip(data)
					.for_each(|(byte, input)| *byte = *input);
				self.storage
					.write(page.start, &self.merge_buffer[..S::ERASE_SIZE])?;
			}
		}
		Ok(())
	}
}