aboutsummaryrefslogtreecommitdiff
path: root/src/tftpd.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/tftpd.rs')
-rw-r--r--src/tftpd.rs145
1 files changed, 105 insertions, 40 deletions
diff --git a/src/tftpd.rs b/src/tftpd.rs
index 2237761..50b8033 100644
--- a/src/tftpd.rs
+++ b/src/tftpd.rs
@@ -1,44 +1,70 @@
use std::net::{SocketAddr,UdpSocket};
use std::fs::File;
+use std::io;
use std::io::prelude::*;
-fn handle_wrq(_cl: &SocketAddr, _buf: &[u8]) {
+fn handle_wrq(_cl: &SocketAddr, _buf: &[u8]) -> Result<(), io::Error> {
+ Ok(())
}
-fn wait_for_ack(sock: &UdpSocket, expected_block: u16) {
+fn wait_for_ack(sock: &UdpSocket, expected_block: u16) -> Result<bool, io::Error> {
let mut buf = [0; 4];
- sock.recv(&mut buf).expect("recv");
+ sock.recv(&mut buf)?;
let opcode = u16::from_be_bytes([buf[0], buf[1]]);
let block_nr = u16::from_be_bytes([buf[2], buf[3]]);
- if opcode != 4 {
- // error
- }
- if block_nr != expected_block {
- // error
+ if opcode == 4 && block_nr == expected_block {
+ return Ok(true)
}
+
+ Ok(false)
}
-fn send_file(cl: &SocketAddr, filename: &str) {
- let mut file = File::open(filename).expect("open");
+fn send_file(cl: &SocketAddr, filename: &str) -> Result<(), io::Error> {
+ let file = File::open(filename);
+ let mut file = match file {
+ Ok(f) => f,
+ Err(ref error) if error.kind() == io::ErrorKind::NotFound => {
+ handle_error(cl, 1, "File not found")?;
+ return Err(io::Error::new(io::ErrorKind::NotFound, "file not found"));
+ },
+ Err(_) => {
+ handle_error(cl, 2, "Permission denied")?;
+ return Err(io::Error::new(io::ErrorKind::PermissionDenied, "permission denied"));
+ }
+ };
- let socket = UdpSocket::bind("0.0.0.0:0").expect("bind");
- socket.connect(cl).expect("connect");
+ let socket = UdpSocket::bind("0.0.0.0:0")?;
+ socket.connect(cl)?;
let mut block_nr: u16 = 1;
loop {
let mut filebuf = [0; 512];
- let n = file.read(&mut filebuf).expect("read");
+ let len = file.read(&mut filebuf);
+ let len = match len {
+ Ok(n) => n,
+ Err(ref error) if error.kind() == io::ErrorKind::Interrupted => continue, /* retry */
+ Err(err) => {
+ handle_error(cl, 0, "File reading error")?;
+ return Err(err);
+ }
+ };
let mut sendbuf = vec![0x00, 0x03]; // opcode
sendbuf.extend(block_nr.to_be_bytes().iter());
- sendbuf.extend(filebuf[0..n].iter());
-
- socket.send(&sendbuf).expect("send");
- wait_for_ack(&socket, block_nr);
+ sendbuf.extend(filebuf[0..len].iter());
+
+ socket.send(&sendbuf)?;
+ for _ in 1..5 {
+ match wait_for_ack(&socket, block_nr) {
+ Ok(true) => break,
+ Ok(false) => continue,
+ Err(e) => return Err(e),
+ };
+ }
- if n < 512 {
+ if len < 512 {
/* this was the last block */
break;
}
@@ -46,60 +72,99 @@ fn send_file(cl: &SocketAddr, filename: &str) {
/* increment with rollover on overflow */
block_nr = block_nr.wrapping_add(1);
}
+ Ok(())
+}
+
+fn file_allowed(_filename: &str) -> bool {
+ // TODO
+ true
}
-fn handle_rrq(cl: &SocketAddr, buf: &[u8]) {
+fn handle_rrq(cl: &SocketAddr, buf: &[u8]) -> Result<(), io::Error> {
let mut iter = buf.iter();
- let fname_len = iter.position(|&x| x == 0).expect("not found");
+ let dataerr = io::Error::new(io::ErrorKind::InvalidData, "invalid data received");
+
+ let fname_len = iter.position(|&x| x == 0);
+ let fname_len = match fname_len {
+ Some(len) => len,
+ None => return Err(dataerr),
+ };
let fname_begin = 0;
let fname_end = fname_begin + fname_len;
- let filename = String::from_utf8(buf[fname_begin .. fname_end].to_vec()).expect("str");
-
- let mode_len = iter.position(|&x| x == 0).expect("not found");
+ let filename = String::from_utf8(buf[fname_begin .. fname_end].to_vec());
+ let filename = match filename {
+ Ok(fname) => fname,
+ Err(_) => return Err(dataerr),
+ };
+
+ let mode_len = iter.position(|&x| x == 0);
+ let mode_len = match mode_len {
+ Some(len) => len,
+ None => return Err(dataerr),
+ };
let mode_begin = fname_end + 1;
let mode_end = mode_begin + mode_len;
- let mode = String::from_utf8(buf[mode_begin .. mode_end].to_vec()).expect("str");
- let mode = mode.to_lowercase();
+ let mode = String::from_utf8(buf[mode_begin .. mode_end].to_vec());
+ let mode = match mode {
+ Ok(m) => m.to_lowercase(),
+ Err(_) => return Err(dataerr),
+ };
match mode.as_ref() {
- "octet" => println!("octet mode"),
- _ => handle_error(cl, 0, "Unsupported mode"),
+ "octet" => (),
+ _ => handle_error(cl, 0, "Unsupported mode")?,
}
- println!("Sending {} to {}", filename, cl);
- send_file(&cl, &filename);
+ match file_allowed(&filename) {
+ true => (),
+ false => {
+ handle_error(cl, 2, "Permission denied")?;
+ return Err(io::Error::new(io::ErrorKind::PermissionDenied, "permission denied"));
+ }
+ }
+
+ match send_file(&cl, &filename) {
+ Ok(_) => println!("Sent {} to {}.", filename, cl),
+ Err(_) => println!("Sending {} to {} failed.", filename, cl),
+ }
+ Ok(())
}
-fn handle_error(cl: &SocketAddr, code: u16, msg: &str) {
- let socket = UdpSocket::bind("0.0.0.0:0").expect("bind");
- socket.connect(cl).expect("connect");
+fn handle_error(cl: &SocketAddr, code: u16, msg: &str) -> Result<(), io::Error> {
+ let socket = UdpSocket::bind("0.0.0.0:0")?;
+ socket.connect(cl)?;
let mut buf = vec![0x00, 0x05]; // opcode
buf.extend(code.to_be_bytes().iter());
buf.extend(msg.as_bytes());
- socket.send(&buf).expect("send");
+ socket.send(&buf)?;
+ Ok(())
}
-fn handle_client(cl: &SocketAddr, buf: &[u8]) {
+fn handle_client(cl: &SocketAddr, buf: &[u8]) -> Result<(), io::Error> {
let opcode = u16::from_be_bytes([buf[0], buf[1]]);
match opcode {
- 1 /* RRQ */ => handle_rrq(&cl, &buf[2..]),
- 2 /* WRQ */ => handle_wrq(&cl, &buf[2..]),
+ 1 /* RRQ */ => handle_rrq(&cl, &buf[2..])?,
+ 2 /* WRQ */ => handle_wrq(&cl, &buf[2..])?,
5 /* ERROR */ => println!("Received ERROR from {}", cl),
- _ => handle_error(cl, 4, "Unexpected opcode"),
+ _ => handle_error(cl, 4, "Unexpected opcode")?,
}
+ Ok(())
}
fn main() {
- let socket = UdpSocket::bind("127.0.0.1:12345").expect("bind");
+ let socket = UdpSocket::bind("127.0.0.1:12345").expect("Binding a socket failed.");
loop {
let mut buf = [0; 2048];
- let (n, src) = socket.recv_from(&mut buf).expect("recv");
+ let (n, src) = socket.recv_from(&mut buf).expect("Receiving from the socket failed.");
- handle_client(&src, &buf[0..n]);
+ match handle_client(&src, &buf[0..n]) {
+ /* errors intentionally ignored */
+ _ => (),
+ }
}
}