lib.rs 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. /// Parser for the MNIST handwriting recognition data set.
  2. ///
  3. /// http://yann.lecun.com/exdb/mnist/
  4. extern crate byteorder;
  5. use std::io::Result;
  6. use std::io::Read;
  7. use byteorder::ReadBytesExt;
  /// Element type of the values stored in an IDX file.
  ///
  /// The variant is selected from the third byte of the file's magic number.
  /// Byte widths quoted below follow the IDX format description published
  /// with the MNIST data set.
  #[derive(Debug, Clone, Copy, PartialEq, Eq)]
  enum DataType {
      /// Unsigned 8-bit values.
      UnsignedByte,
      /// Signed 8-bit values.
      SignedByte,
      /// 2-byte integers.
      Short,
      /// 4-byte integers.
      Int,
      /// 4-byte floating point values.
      Float,
      /// 8-byte floating point values.
      Double,
  }
  17. #[derive(Debug)]
  18. pub struct Idx {
  19. data_type: DataType,
  20. pub dimensions: Vec<u32>,
  21. }
  /// Reads and validates the 4-byte magic number that starts an IDX stream.
  ///
  /// Exactly four bytes are consumed from `src` and returned unchanged. Only
  /// the first two bytes are validated here; they must both be zero.
  ///
  /// # Errors
  ///
  /// * Any I/O error from `src`, including `ErrorKind::UnexpectedEof` when
  ///   fewer than four bytes are available.
  /// * `ErrorKind::InvalidData` when either of the first two bytes is
  ///   non-zero.
  fn magic<T: Read>(src: &mut T) -> Result<[u8; 4]> {
      let mut header = [0u8; 4];
      // `read_exact` keeps reading until the buffer is full; a single `read`
      // call is allowed to return fewer bytes than requested.
      src.read_exact(&mut header)?;
      if header[0] != 0 || header[1] != 0 {
          return Err(std::io::Error::new(
              std::io::ErrorKind::InvalidData,
              "IDX magic number must start with two zero bytes",
          ));
      }
      Ok(header)
  }
  31. pub fn header<T: Read>(src: &mut T) -> Result<Idx> {
  32. let header = try!(magic(src));
  33. let data_type = match header[2] {
  34. 0x08 => DataType::UnsignedByte,
  35. 0x09 => DataType::SignedByte,
  36. 0x0b => DataType::Short,
  37. 0x0c => DataType::Int,
  38. 0x0d => DataType::Float,
  39. 0x0f => DataType::Double,
  40. v => panic!(format!("Unknown data type {} in header!", v)),
  41. };
  42. let mut dim = Vec::new();
  43. for _ in 0..header[3] {
  44. let size = try!(src.read_u32::<byteorder::BigEndian>());
  45. dim.push(size);
  46. }
  47. Ok(Idx {
  48. data_type: data_type,
  49. dimensions: dim,
  50. })
  51. }
  /// A well-formed magic number is returned byte-for-byte.
  #[test]
  fn valid_magic() {
      let mut input = std::io::Cursor::new(vec![0,0,8,3]);
      let bytes = magic(&mut input).unwrap();
      // Comparing against a 4-element array checks the length and every byte.
      assert_eq!(bytes, [0, 0, 8, 3]);
  }
  /// A header declaring `Int` elements and two dimensions parses correctly.
  #[test]
  fn valid_header() {
      // Magic number 00 00 0c 02, then two big-endian sizes:
      // 00 00 00 0b (11) and 00 01 00 01 (65537).
      let bytes = vec![0,0,0x0c,2,0,0,0,11,0,1,0,1];
      let mut input = std::io::Cursor::new(bytes);
      let idx = header(&mut input).unwrap();
      assert_eq!(idx.data_type, DataType::Int);
      assert_eq!(idx.dimensions, vec![11, 65537]);
  }