diff options
Diffstat (limited to 'src/tftpd.rs')
| -rw-r--r-- | src/tftpd.rs | 145 |
1 files changed, 105 insertions, 40 deletions
diff --git a/src/tftpd.rs b/src/tftpd.rs index 2237761..50b8033 100644 --- a/src/tftpd.rs +++ b/src/tftpd.rs @@ -1,44 +1,70 @@ use std::net::{SocketAddr,UdpSocket}; use std::fs::File; +use std::io; use std::io::prelude::*; -fn handle_wrq(_cl: &SocketAddr, _buf: &[u8]) { +fn handle_wrq(_cl: &SocketAddr, _buf: &[u8]) -> Result<(), io::Error> { + Ok(()) } -fn wait_for_ack(sock: &UdpSocket, expected_block: u16) { +fn wait_for_ack(sock: &UdpSocket, expected_block: u16) -> Result<bool, io::Error> { let mut buf = [0; 4]; - sock.recv(&mut buf).expect("recv"); + sock.recv(&mut buf)?; let opcode = u16::from_be_bytes([buf[0], buf[1]]); let block_nr = u16::from_be_bytes([buf[2], buf[3]]); - if opcode != 4 { - // error - } - if block_nr != expected_block { - // error + if opcode == 4 && block_nr == expected_block { + return Ok(true) } + + Ok(false) } -fn send_file(cl: &SocketAddr, filename: &str) { - let mut file = File::open(filename).expect("open"); +fn send_file(cl: &SocketAddr, filename: &str) -> Result<(), io::Error> { + let file = File::open(filename); + let mut file = match file { + Ok(f) => f, + Err(ref error) if error.kind() == io::ErrorKind::NotFound => { + handle_error(cl, 1, "File not found")?; + return Err(io::Error::new(io::ErrorKind::NotFound, "file not found")); + }, + Err(_) => { + handle_error(cl, 2, "Permission denied")?; + return Err(io::Error::new(io::ErrorKind::PermissionDenied, "permission denied")); + } + }; - let socket = UdpSocket::bind("0.0.0.0:0").expect("bind"); - socket.connect(cl).expect("connect"); + let socket = UdpSocket::bind("0.0.0.0:0")?; + socket.connect(cl)?; let mut block_nr: u16 = 1; loop { let mut filebuf = [0; 512]; - let n = file.read(&mut filebuf).expect("read"); + let len = file.read(&mut filebuf); + let len = match len { + Ok(n) => n, + Err(ref error) if error.kind() == io::ErrorKind::Interrupted => continue, /* retry */ + Err(err) => { + handle_error(cl, 0, "File reading error")?; + return Err(err); + } + }; let mut sendbuf = vec![0x00, 0x03]; // opcode sendbuf.extend(block_nr.to_be_bytes().iter()); - sendbuf.extend(filebuf[0..n].iter()); - - socket.send(&sendbuf).expect("send"); - wait_for_ack(&socket, block_nr); + sendbuf.extend(filebuf[0..len].iter()); + + socket.send(&sendbuf)?; + for _ in 1..5 { + match wait_for_ack(&socket, block_nr) { + Ok(true) => break, + Ok(false) => continue, + Err(e) => return Err(e), + }; + } - if n < 512 { + if len < 512 { /* this was the last block */ break; } @@ -46,60 +72,99 @@ fn send_file(cl: &SocketAddr, filename: &str) { /* increment with rollover on overflow */ block_nr = block_nr.wrapping_add(1); } + Ok(()) +} + +fn file_allowed(_filename: &str) -> bool { + // TODO + true } -fn handle_rrq(cl: &SocketAddr, buf: &[u8]) { +fn handle_rrq(cl: &SocketAddr, buf: &[u8]) -> Result<(), io::Error> { let mut iter = buf.iter(); - let fname_len = iter.position(|&x| x == 0).expect("not found"); + let dataerr = io::Error::new(io::ErrorKind::InvalidData, "invalid data received"); + + let fname_len = iter.position(|&x| x == 0); + let fname_len = match fname_len { + Some(len) => len, + None => return Err(dataerr), + }; let fname_begin = 0; let fname_end = fname_begin + fname_len; - let filename = String::from_utf8(buf[fname_begin .. fname_end].to_vec()).expect("str"); - - let mode_len = iter.position(|&x| x == 0).expect("not found"); + let filename = String::from_utf8(buf[fname_begin .. fname_end].to_vec()); + let filename = match filename { + Ok(fname) => fname, + Err(_) => return Err(dataerr), + }; + + let mode_len = iter.position(|&x| x == 0); + let mode_len = match mode_len { + Some(len) => len, + None => return Err(dataerr), + }; let mode_begin = fname_end + 1; let mode_end = mode_begin + mode_len; - let mode = String::from_utf8(buf[mode_begin .. mode_end].to_vec()).expect("str"); - let mode = mode.to_lowercase(); + let mode = String::from_utf8(buf[mode_begin .. mode_end].to_vec()); + let mode = match mode { + Ok(m) => m.to_lowercase(), + Err(_) => return Err(dataerr), + }; match mode.as_ref() { - "octet" => println!("octet mode"), - _ => handle_error(cl, 0, "Unsupported mode"), + "octet" => (), + _ => handle_error(cl, 0, "Unsupported mode")?, } - println!("Sending {} to {}", filename, cl); - send_file(&cl, &filename); + match file_allowed(&filename) { + true => (), + false => { + handle_error(cl, 2, "Permission denied")?; + return Err(io::Error::new(io::ErrorKind::PermissionDenied, "permission denied")); + } + } + + match send_file(&cl, &filename) { + Ok(_) => println!("Sent {} to {}.", filename, cl), + Err(_) => println!("Sending {} to {} failed.", filename, cl), + } + Ok(()) } -fn handle_error(cl: &SocketAddr, code: u16, msg: &str) { - let socket = UdpSocket::bind("0.0.0.0:0").expect("bind"); - socket.connect(cl).expect("connect"); +fn handle_error(cl: &SocketAddr, code: u16, msg: &str) -> Result<(), io::Error> { + let socket = UdpSocket::bind("0.0.0.0:0")?; + socket.connect(cl)?; let mut buf = vec![0x00, 0x05]; // opcode buf.extend(code.to_be_bytes().iter()); buf.extend(msg.as_bytes()); - socket.send(&buf).expect("send"); + socket.send(&buf)?; + Ok(()) } -fn handle_client(cl: &SocketAddr, buf: &[u8]) { +fn handle_client(cl: &SocketAddr, buf: &[u8]) -> Result<(), io::Error> { let opcode = u16::from_be_bytes([buf[0], buf[1]]); match opcode { - 1 /* RRQ */ => handle_rrq(&cl, &buf[2..]), - 2 /* WRQ */ => handle_wrq(&cl, &buf[2..]), + 1 /* RRQ */ => handle_rrq(&cl, &buf[2..])?, + 2 /* WRQ */ => handle_wrq(&cl, &buf[2..])?, 5 /* ERROR */ => println!("Received ERROR from {}", cl), - _ => handle_error(cl, 4, "Unexpected opcode"), + _ => handle_error(cl, 4, "Unexpected opcode")?, } + Ok(()) } fn main() { - let socket = UdpSocket::bind("127.0.0.1:12345").expect("bind"); + let socket = UdpSocket::bind("127.0.0.1:12345").expect("Binding a socket failed."); loop { let mut buf = [0; 2048]; - let (n, src) = socket.recv_from(&mut buf).expect("recv"); + let (n, src) = socket.recv_from(&mut buf).expect("Receiving from the socket failed."); - handle_client(&src, &buf[0..n]); + match handle_client(&src, &buf[0..n]) { + /* errors intentionally ignored */ + _ => (), + } } } |
