diff options
| -rw-r--r-- | src/tftp.rs | 64 | ||||
| -rw-r--r-- | src/tftpd.rs | 18 |
2 files changed, 51 insertions, 31 deletions
diff --git a/src/tftp.rs b/src/tftp.rs index 67dcd1a..3319bc4 100644 --- a/src/tftp.rs +++ b/src/tftp.rs @@ -12,15 +12,25 @@ use std::io; use std::io::prelude::*; pub struct TftpOptions { - + blksize: usize, } pub struct Tftp { + options: TftpOptions, +} + +fn default_options() -> TftpOptions { + TftpOptions { + blksize: 512, + } } impl Tftp { + pub fn new() -> Tftp { - Tftp{} + Tftp{ + options: default_options(), + } } fn get_tftp_str(&self, buf: &[u8]) -> Option<(String, usize)> { @@ -60,6 +70,10 @@ impl Tftp { fn ack_options(&self, sock: &UdpSocket, options: &HashMap<String, String>, ackwait: bool) -> Result<(), io::Error> { if options.is_empty() { + if !ackwait { + /* it's a WRQ, send normal ack to start transfer */ + self.send_ack(&sock, 0)?; + } return Ok(()) } @@ -88,13 +102,19 @@ impl Tftp { Err(io::Error::new(io::ErrorKind::TimedOut, "ack timeout")) } - pub fn init_tftp_options(&self, sock: &UdpSocket, options: &mut HashMap<String, String>, ackwait: bool) -> Result<TftpOptions, io::Error> { - let tftpopts = TftpOptions {}; + pub fn init_tftp_options(&mut self, sock: &UdpSocket, options: &mut HashMap<String, String>, ackwait: bool) -> Result<(), io::Error> { + self.options = default_options(); - options.retain(|key, _val| { + options.retain(|key, val| { match key.as_str() { - "placeholder_option" => { - true + "blksize" => { + match val.parse() { + Ok(b) if b >= 8 && b <= 65464 => { + self.options.blksize = b; + true + } + _ => false, + } } _ => false } @@ -102,7 +122,7 @@ impl Tftp { self.ack_options(&sock, &options, ackwait)?; - return Ok(tftpopts); + return Ok(()); } fn parse_options(&self, buf: &[u8]) -> HashMap<String, String> { @@ -189,7 +209,7 @@ impl Tftp { let mut block_nr: u16 = 1; loop { - let mut filebuf = [0; 512]; + let mut filebuf = vec![0; self.options.blksize]; let len = match file.read(&mut filebuf) { Ok(n) => n, Err(ref error) if error.kind() == io::ErrorKind::Interrupted => continue, /* retry */ @@ -222,7 +242,7 @@ impl Tftp { return Err(io::Error::new(io::ErrorKind::TimedOut, "ack timeout")) } - if len < 512 { + if len < self.options.blksize { /* this was the last block */ break; } @@ -242,26 +262,26 @@ impl Tftp { Err(_) => return Err(io::Error::new(io::ErrorKind::PermissionDenied, "permission denied")), }; - let mut block_nr = 0; + let mut block_nr: u16 = 1; loop { - let mut buf = [0; 1024]; + let mut buf = vec![0; 4 + self.options.blksize + 1]; // +1 for later size check let mut len = 0; for _ in 1..5 { - self.send_ack(&sock, block_nr)?; len = match sock.recv(&mut buf) { Ok(n) => n, Err(ref error) if [io::ErrorKind::WouldBlock, io::ErrorKind::TimedOut].contains(&error.kind()) => { - /* re-ack and try to recv again */ + /* re-ack previous and try to recv again */ + self.send_ack(&sock, block_nr - 1)?; continue; } Err(err) => return Err(err), }; break; } - if len > 516 || len < 4 { - /* max size: 2 + 2 + 512 */ + if len < 4 || len > 4 + self.options.blksize { + /* max size: 2 + 2 + blksize */ return Err(io::Error::new(io::ErrorKind::InvalidInput, "unexpected size")); } @@ -269,25 +289,25 @@ impl Tftp { 3 /* DATA */ => (), _ => return Err(io::Error::new(io::ErrorKind::Other, "unexpected opcode")), }; - let nr = u16::from_be_bytes([buf[2], buf[3]]); - if nr != block_nr.wrapping_add(1) { + if u16::from_be_bytes([buf[2], buf[3]]) != block_nr { /* already received or packets were missed, re-acknowledge */ + self.send_ack(&sock, block_nr - 1)?; continue; } - block_nr = nr; let databuf = &buf[4..len]; file.write_all(databuf)?; - if len < 516 { + self.send_ack(&sock, block_nr)?; + block_nr = block_nr.wrapping_add(1); + + if len < 4 + self.options.blksize { break; } } file.flush()?; - self.send_ack(&sock, block_nr)?; - Ok(()) } diff --git a/src/tftpd.rs b/src/tftpd.rs index afb8205..dd3d875 100644 --- a/src/tftpd.rs +++ b/src/tftpd.rs @@ -71,9 +71,9 @@ impl Tftpd { } - fn handle_wrq(&self, socket: &UdpSocket, cl: &SocketAddr, buf: &[u8]) -> Result<(), io::Error> { + fn handle_wrq(&mut self, socket: &UdpSocket, cl: &SocketAddr, buf: &[u8]) -> Result<(), io::Error> { let (filename, mode, mut options) = self.tftp.parse_file_mode_options(buf)?; - let _opts = self.tftp.init_tftp_options(&socket, &mut options, false); + self.tftp.init_tftp_options(&socket, &mut options, false)?; match mode.as_ref() { "octet" => (), @@ -107,9 +107,9 @@ impl Tftpd { Ok(()) } - fn handle_rrq(&self, socket: &UdpSocket, cl: &SocketAddr, buf: &[u8]) -> Result<(), io::Error> { + fn handle_rrq(&mut self, socket: &UdpSocket, cl: &SocketAddr, buf: &[u8]) -> Result<(), io::Error> { let (filename, mode, mut options) = self.tftp.parse_file_mode_options(buf)?; - let _opts = self.tftp.init_tftp_options(&socket, &mut options, true); + self.tftp.init_tftp_options(&socket, &mut options, true)?; match mode.as_ref() { "octet" => (), @@ -135,14 +135,14 @@ impl Tftpd { Ok(()) } - pub fn handle_client(&self, conf: &Configuration, cl: &SocketAddr, buf: &[u8]) -> Result<(), io::Error> { + pub fn handle_client(&mut self, cl: &SocketAddr, buf: &[u8]) -> Result<(), io::Error> { let socket = UdpSocket::bind("0.0.0.0:0")?; socket.connect(cl)?; socket.set_read_timeout(Some(Duration::from_secs(5)))?; let _opcode = match u16::from_be_bytes([buf[0], buf[1]]) { 1 /* RRQ */ => { - if conf.wo { + if self.conf.wo { self.tftp.send_error(&socket, 4, "reading not allowed")?; return Err(io::Error::new(io::ErrorKind::Other, "unallowed mode")); } else { @@ -150,7 +150,7 @@ impl Tftpd { } }, 2 /* WRQ */ => { - if conf.ro { + if self.conf.ro { self.tftp.send_error(&socket, 4, "writing not allowed")?; return Err(io::Error::new(io::ErrorKind::Other, "unallowed mode")); } else { @@ -189,7 +189,7 @@ impl Tftpd { Ok(()) } - pub fn start(&self) { + pub fn start(&mut self) { let socket = match UdpSocket::bind(format!("0.0.0.0:{}", self.conf.port)) { Ok(s) => s, Err(err) => { @@ -223,7 +223,7 @@ impl Tftpd { } }; - match self.handle_client(&self.conf, &src, &buf[0..n]) { + match self.handle_client(&src, &buf[0..n]) { /* errors intentionally ignored */ _ => (), } |
