summaryrefslogtreecommitdiff
path: root/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib.rs')
-rw-r--r--src/lib.rs166
1 files changed, 151 insertions, 15 deletions
diff --git a/src/lib.rs b/src/lib.rs
index f733552..4142e7f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -20,7 +20,7 @@ pub static VERSION: Option<&str> = option_env!("CARGO_PKG_VERSION");
type ProgressCallback = fn(cur: u64, total: u64, state: u64) -> u64;
#[repr(u16)]
-pub enum Opcodes {
+pub enum Opcode {
RRQ = 0x01,
WRQ = 0x02,
DATA = 0x03,
@@ -30,6 +30,13 @@ pub enum Opcodes {
}
#[derive(Clone, Copy)]
+#[repr(u8)]
+pub enum Mode {
+ OCTET,
+ NETASCII,
+}
+
+#[derive(Clone, Copy)]
pub struct TftpOptions {
blksize: usize,
timeout: u8,
@@ -39,6 +46,7 @@ pub struct TftpOptions {
#[derive(Clone, Copy)]
pub struct Tftp {
options: TftpOptions,
+ mode: Mode,
progress_cb: Option<ProgressCallback>,
}
@@ -50,10 +58,46 @@ fn default_options() -> TftpOptions {
}
}
+fn netascii_to_octet(buf: &[u8], previous_cr: bool) -> (Vec<u8>, bool) {
+ let mut out = Vec::with_capacity(buf.len());
+
+ let mut prev_cr = previous_cr;
+ for b in buf {
+ match *b {
+ b'\r' => {
+ if prev_cr {
+ out.push(b'\r');
+ }
+ prev_cr = true;
+ continue;
+ }
+ b'\0' if prev_cr => out.push(b'\r'),
+ b'\n' if prev_cr => out.push(b'\n'),
+ _ => out.push(*b),
+ }
+ prev_cr = false;
+ }
+ (out, prev_cr)
+}
+
+fn octet_to_netascii(buf: &[u8]) -> Vec<u8> {
+ let mut out = Vec::with_capacity(2 * buf.len());
+
+ for b in buf {
+ match *b {
+ b'\r' => out.extend(b"\r\0"),
+ b'\n' => out.extend(b"\r\n"),
+ _ => out.push(*b),
+ }
+ }
+ out
+}
+
impl Default for Tftp {
fn default() -> Tftp {
Tftp {
options: default_options(),
+ mode: Mode::OCTET,
progress_cb: None,
}
}
@@ -64,6 +108,36 @@ impl Tftp {
Default::default()
}
+ pub fn transfersize(&self, file: &mut File) -> Result<u64, io::Error> {
+ match self.mode {
+ Mode::OCTET => return Ok(file.metadata().expect("failed to get metadata").len()),
+ Mode::NETASCII => {},
+ }
+
+ let mut total_size = 0;
+ loop {
+ let mut buf = [0; 4096];
+ let size = match file.read(&mut buf) {
+ Ok(0) => break,
+ Ok(s) => s,
+ Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
+ Err(err) => return Err(err),
+ };
+ total_size += size as u64;
+ /* each \r and \n will take two bytes in netascii output */
+ total_size += buf[0..size].iter()
+ .filter(|&x| *x == b'\r' || *x == b'\n')
+ .count() as u64;
+ }
+
+ file.seek(io::SeekFrom::Start(0))?;
+ Ok(total_size)
+ }
+
+ pub fn set_mode(&mut self, mode: Mode) {
+ self.mode = mode;
+ }
+
fn get_tftp_str(&self, buf: &[u8]) -> Option<String> {
let mut iter = buf.iter();
@@ -106,7 +180,7 @@ impl Tftp {
}
let opcode = u16::from_be_bytes([buf[0], buf[1]]);
- if opcode != Opcodes::ERROR as u16 {
+ if opcode != Opcode::ERROR as u16 {
return std::io::Error::new(kind, error);
}
@@ -147,9 +221,9 @@ impl Tftp {
let opcode = u16::from_be_bytes([buf[0], buf[1]]);
let block_nr = u16::from_be_bytes([buf[2], buf[3]]);
- if opcode == Opcodes::ACK as u16 && block_nr == expected_block {
+ if opcode == Opcode::ACK as u16 && block_nr == expected_block {
return Ok(true);
- } else if opcode == Opcodes::ERROR as u16 {
+ } else if opcode == Opcode::ERROR as u16 {
return Err(self.parse_error(&buf[4..]));
}
@@ -166,7 +240,7 @@ impl Tftp {
}
let mut buf = Vec::with_capacity(512);
- buf.extend((Opcodes::OACK as u16).to_be_bytes().iter());
+ buf.extend((Opcode::OACK as u16).to_be_bytes().iter());
for (key, val) in options {
self.append_option(&mut buf, key, val);
@@ -271,7 +345,7 @@ impl Tftp {
pub fn send_error(&self, socket: &UdpSocket, code: u16, msg: &str) -> Result<(), io::Error> {
let mut buf = Vec::with_capacity(512);
- buf.extend((Opcodes::ERROR as u16).to_be_bytes().iter());
+ buf.extend((Opcode::ERROR as u16).to_be_bytes().iter());
buf.extend(code.to_be_bytes().iter());
buf.extend(msg.as_bytes());
@@ -281,7 +355,7 @@ impl Tftp {
fn _send_ack(&self, sock: &UdpSocket, cl: Option<SocketAddr>, block_nr: u16) -> Result<(), io::Error> {
let mut buf = Vec::with_capacity(4);
- buf.extend((Opcodes::ACK as u16).to_be_bytes().iter());
+ buf.extend((Opcode::ACK as u16).to_be_bytes().iter());
buf.extend(block_nr.to_be_bytes().iter());
match cl {
@@ -305,9 +379,12 @@ impl Tftp {
let mut prog_update = 0;
let tsize = self.transfer_size(file);
+ /* holds bytes from netascii conversion that did not fit in tx buffer */
+ let mut overflow = Vec::with_capacity(2 * self.options.blksize);
+
loop {
- let mut filebuf = vec![0; self.options.blksize];
- let len = match file.read(&mut filebuf) {
+ let mut filebuf = vec![0; self.options.blksize - overflow.len()];
+ let mut len = match file.read(&mut filebuf) {
Ok(n) => n,
Err(ref error) if error.kind() == io::ErrorKind::Interrupted => continue, /* retry */
Err(err) => {
@@ -316,10 +393,26 @@ impl Tftp {
}
};
+ /* take care of netascii conversion */
+ let mut databuf = filebuf[0..len].to_vec();
+ match self.mode {
+ Mode::OCTET => {},
+ Mode::NETASCII => {
+ overflow.extend(octet_to_netascii(&databuf));
+ databuf = overflow.clone();
+ if overflow.len() > self.options.blksize {
+ overflow = databuf.split_off(self.options.blksize);
+ } else {
+ overflow.clear();
+ }
+ len = databuf.len();
+ }
+ }
+
let mut sendbuf = Vec::with_capacity(4 + len);
- sendbuf.extend((Opcodes::DATA as u16).to_be_bytes().iter());
+ sendbuf.extend((Opcode::DATA as u16).to_be_bytes().iter());
sendbuf.extend(block_nr.to_be_bytes().iter());
- sendbuf.extend(filebuf[0..len].iter());
+ sendbuf.extend(databuf.iter());
let mut acked = false;
for _ in 1..5 {
@@ -359,6 +452,7 @@ impl Tftp {
let mut block_nr: u16 = 1;
let mut prog_update = 0;
let mut transferred = 0;
+ let mut netascii_state = false;
let tsize = self.transfer_size(file);
loop {
@@ -383,8 +477,8 @@ impl Tftp {
}
match u16::from_be_bytes([buf[0], buf[1]]) { // opcode
- opc if opc == Opcodes::DATA as u16 => (),
- opc if opc == Opcodes::ERROR as u16 => return Err(self.parse_error(&buf[..len])),
+ opc if opc == Opcode::DATA as u16 => (),
+ opc if opc == Opcode::ERROR as u16 => return Err(self.parse_error(&buf[..len])),
_ => return Err(io::Error::new(io::ErrorKind::Other, "unexpected opcode")),
};
if u16::from_be_bytes([buf[2], buf[3]]) != block_nr {
@@ -393,8 +487,16 @@ impl Tftp {
continue;
}
- let databuf = &buf[4..len];
- file.write_all(databuf)?;
+ let mut databuf = buf[4..len].to_vec();
+ match self.mode {
+ Mode::OCTET => {},
+ Mode::NETASCII => {
+ let (converted, state) = netascii_to_octet(&databuf, netascii_state);
+ databuf = converted;
+ netascii_state = state;
+ }
+ }
+ file.write_all(&databuf)?;
transferred += (len - 4) as u64;
if let Some(cb) = self.progress_cb {
@@ -409,6 +511,11 @@ impl Tftp {
}
}
+ if netascii_state {
+ /* the file ended with an incomplete \r encoding */
+ file.write(&[b'\r'])?;
+ }
+
file.flush()?;
Ok(())
@@ -492,4 +599,33 @@ mod tests {
tftp.append_option(&mut buf, "key", "value");
assert_eq!(buf, "key\x00value\x00".as_bytes());
}
+
+ #[test]
+ fn test_netascii_to_octet() {
+ assert_eq!(netascii_to_octet(b"\r\nfoo\r\0bar", false), (b"\nfoo\rbar".to_vec(), false));
+ assert_eq!(netascii_to_octet(b"\r\0", false), (b"\r".to_vec(), false));
+ assert_eq!(netascii_to_octet(b"\r\n", false), (b"\n".to_vec(), false));
+ assert_eq!(netascii_to_octet(b"", false), (b"".to_vec(), false));
+ assert_eq!(netascii_to_octet(b"\n\0\n\0", false), (b"\n\0\n\0".to_vec(), false));
+ assert_eq!(netascii_to_octet(b"\r\r\n", false), (b"\r\n".to_vec(), false));
+ assert_eq!(netascii_to_octet(b"\r\n\r\n", false), (b"\n\n".to_vec(), false));
+ assert_eq!(netascii_to_octet(b"test\r\0", false), (b"test\r".to_vec(), false));
+ assert_eq!(netascii_to_octet(b"test\r", false), (b"test".to_vec(), true));
+ assert_eq!(netascii_to_octet(b"\r", false), (b"".to_vec(), true));
+ assert_eq!(netascii_to_octet(b"\0test", true), (b"\rtest".to_vec(), false));
+ assert_eq!(netascii_to_octet(b"\ntest", true), (b"\ntest".to_vec(), false));
+ assert_eq!(netascii_to_octet(b"\n\r", true), (b"\n".to_vec(), true));
+ assert_eq!(netascii_to_octet(b"", true), (b"".to_vec(), true));
+ assert_eq!(netascii_to_octet(b"\r", true), (b"\r".to_vec(), true));
+ }
+
+ #[test]
+ fn test_octet_to_netascii() {
+ assert_eq!(octet_to_netascii(b"foobar"), b"foobar");
+ assert_eq!(octet_to_netascii(b"foo\rbar\n"), b"foo\r\0bar\r\n");
+ assert_eq!(octet_to_netascii(b"\r\n"), b"\r\0\r\n");
+ assert_eq!(octet_to_netascii(b"\r\r\n\n"), b"\r\0\r\0\r\n\r\n");
+ assert_eq!(octet_to_netascii(b"\r\0\r\n"), b"\r\0\0\r\0\r\n");
+ assert_eq!(octet_to_netascii(b""), b"");
+ }
}