diff options
| -rw-r--r-- | src/tftpd.rs | 208 |
1 files changed, 163 insertions, 45 deletions
diff --git a/src/tftpd.rs b/src/tftpd.rs index a0b718f..bac27e6 100644 --- a/src/tftpd.rs +++ b/src/tftpd.rs @@ -19,10 +19,6 @@ struct Configuration { gid: u32, } -fn handle_wrq(_cl: &SocketAddr, _buf: &[u8]) -> Result<(), io::Error> { - Ok(()) -} - fn wait_for_ack(sock: &UdpSocket, expected_block: u16) -> Result<bool, io::Error> { let mut buf = [0; 4]; match sock.recv(&mut buf) { @@ -43,21 +39,51 @@ fn wait_for_ack(sock: &UdpSocket, expected_block: u16) -> Result<bool, io::Error Ok(false) } +fn parse_file_mode(buf: &[u8]) -> Result<(PathBuf, String), io::Error> { + let mut iter = buf.iter(); + + let dataerr = io::Error::new(io::ErrorKind::InvalidData, "invalid data received"); + + let fname_len = match iter.position(|&x| x == 0) { + Some(len) => len, + None => return Err(dataerr), + }; + let fname_begin = 0; + let fname_end = fname_begin + fname_len; + let filename = match String::from_utf8(buf[fname_begin .. fname_end].to_vec()) { + Ok(fname) => fname, + Err(_) => return Err(dataerr), + }; + let filename = Path::new(&filename); + + let mode_len = match iter.position(|&x| x == 0) { + Some(len) => len, + None => return Err(dataerr), + }; + let mode_begin = fname_end + 1; + let mode_end = mode_begin + mode_len; + let mode = match String::from_utf8(buf[mode_begin .. mode_end].to_vec()) { + Ok(m) => m.to_lowercase(), + Err(_) => return Err(dataerr), + }; + + Ok((filename.to_path_buf(), mode)) +} + fn send_file(cl: &SocketAddr, path: &Path) -> Result<(), io::Error> { - let file = File::open(path); - let mut file = match file { + let mut file = match File::open(path) { Ok(f) => f, Err(ref error) if error.kind() == io::ErrorKind::NotFound => { - handle_error(cl, 1, "File not found")?; + send_error(cl, 1, "File not found")?; return Err(io::Error::new(io::ErrorKind::NotFound, "file not found")); }, Err(_) => { - handle_error(cl, 2, "Permission denied")?; + send_error(cl, 2, "Permission denied")?; return Err(io::Error::new(io::ErrorKind::PermissionDenied, "permission denied")); } }; if !file.metadata()?.is_file() { - handle_error(cl, 1, "File not found")?; + send_error(cl, 1, "File not found")?; return Err(io::Error::new(io::ErrorKind::NotFound, "file not found")); } @@ -73,7 +99,7 @@ fn send_file(cl: &SocketAddr, path: &Path) -> Result<(), io::Error> { Ok(n) => n, Err(ref error) if error.kind() == io::ErrorKind::Interrupted => continue, /* retry */ Err(err) => { - handle_error(cl, 0, "File reading error")?; + send_error(cl, 0, "File reading error")?; return Err(err); } }; @@ -104,12 +130,79 @@ fn send_file(cl: &SocketAddr, path: &Path) -> Result<(), io::Error> { Ok(()) } +fn recv_file(sock: &UdpSocket, path: &PathBuf) -> Result<(), io::Error> { + let mut file = match File::create(path) { + Ok(f) => f, + Err(_) => return Err(io::Error::new(io::ErrorKind::PermissionDenied, "permission denied")), + }; + + let mut block_nr = 0; + + loop { + let mut buf = [0; 1024]; + let mut len = 0; + + for _ in 1..5 { + send_ack(&sock, block_nr)?; + len = match sock.recv(&mut buf) { + Ok(n) => n, + Err(ref error) if [io::ErrorKind::WouldBlock, io::ErrorKind::TimedOut].contains(&error.kind()) => { + /* re-ack and try to recv again */ + continue; + } + Err(err) => return Err(err), + }; + } + if len > 516 || len < 4 { + /* max size: 2 + 2 + 512 */ + return Err(io::Error::new(io::ErrorKind::InvalidInput, "unexpected size")); + } + + let _opcode = match u16::from_be_bytes([buf[0], buf[1]]) { + 3 /* DATA */ => (), + _ => return Err(io::Error::new(io::ErrorKind::Other, "unexpected opcode")), + }; + let nr = u16::from_be_bytes([buf[2], buf[3]]); + if nr != block_nr.wrapping_add(1) { + /* already received or packets were missed, re-acknowledge */ + continue; + } + block_nr = nr; + + let databuf = &buf[4..len]; + file.write_all(databuf)?; + + if len < 516 { + break; + } + } + + file.flush()?; + + send_ack(&sock, block_nr)?; + + Ok(()) +} + fn file_allowed(filename: &Path) -> Option<PathBuf> { - let path = match filename.canonicalize() { + /* get parent to check dir where file should be read/written */ + let path = Path::new(".").join(filename); + let path = match path.parent() { + Some(p) => p, + None => return None, + }; + let path = match path.canonicalize() { Ok(p) => p, Err(_) => return None, }; + /* get last component to append to canonicalized path */ + let filename = match filename.file_name() { + Some(f) => f, + None => return None, + }; + let path = path.join(filename); + let cwd = match env::current_dir() { Ok(p) => p, Err(_) => return None, @@ -121,56 +214,71 @@ fn file_allowed(filename: &Path) -> Option<PathBuf> { } } -fn handle_rrq(cl: &SocketAddr, buf: &[u8]) -> Result<(), io::Error> { - let mut iter = buf.iter(); +fn handle_wrq(cl: &SocketAddr, buf: &[u8]) -> Result<(), io::Error> { + let (filename, mode) = parse_file_mode(buf)?; - let dataerr = io::Error::new(io::ErrorKind::InvalidData, "invalid data received"); + match mode.as_ref() { + "octet" => (), + _ => { + send_error(cl, 0, "Unsupported mode")?; + return Err(io::Error::new(io::ErrorKind::Other, "unsupported mode")); + } + } - let fname_len = match iter.position(|&x| x == 0) { - Some(len) => len, - None => return Err(dataerr), - }; - let fname_begin = 0; - let fname_end = fname_begin + fname_len; - let filename = match String::from_utf8(buf[fname_begin .. fname_end].to_vec()) { - Ok(fname) => fname, - Err(_) => return Err(dataerr), + let path = match file_allowed(&filename) { + Some(p) => p, + None => { + println!("Sending {} to {} failed (permission check failed).", filename.display(), cl); + send_error(cl, 2, "Permission denied")?; + return Err(io::Error::new(io::ErrorKind::PermissionDenied, "permission denied")); + } }; - let filename = Path::new(&filename); - let mode_len = match iter.position(|&x| x == 0) { - Some(len) => len, - None => return Err(dataerr), - }; - let mode_begin = fname_end + 1; - let mode_end = mode_begin + mode_len; - let mode = match String::from_utf8(buf[mode_begin .. mode_end].to_vec()) { - Ok(m) => m.to_lowercase(), - Err(_) => return Err(dataerr), - }; + let socket = UdpSocket::bind("0.0.0.0:0")?; + socket.connect(cl)?; + socket.set_read_timeout(Some(Duration::from_secs(5)))?; + + match recv_file(&socket, &path) { + Ok(_) => println!("Received {} from {}.", path.display(), cl), + Err(err) => { + println!("Receiving {} from {} failed ({}).", path.display(), cl, err.to_string()); + send_error(cl, 0, "Receiving error")?; + return Ok(()) + } + } + + Ok(()) +} + + +fn handle_rrq(cl: &SocketAddr, buf: &[u8]) -> Result<(), io::Error> { + let (filename, mode) = parse_file_mode(buf)?; match mode.as_ref() { "octet" => (), - _ => handle_error(cl, 0, "Unsupported mode")?, + _ => { + send_error(cl, 0, "Unsupported mode")?; + return Err(io::Error::new(io::ErrorKind::Other, "unsupported mode")); + } } let path = match file_allowed(&filename) { - Some(path) => path, + Some(p) => p, None => { - println!("Sending {} to {} failed.", filename.display(), cl); - handle_error(cl, 2, "Permission denied")?; + println!("Sending {} to {} failed (permission check failed).", filename.display(), cl); + send_error(cl, 2, "Permission denied")?; return Err(io::Error::new(io::ErrorKind::PermissionDenied, "permission denied")); } }; match send_file(&cl, &path) { Ok(_) => println!("Sent {} to {}.", path.display(), cl), - Err(_) => println!("Sending {} to {} failed.", path.display(), cl), + Err(err) => println!("Sending {} to {} failed ({}).", path.display(), cl, err.to_string()), } Ok(()) } -fn handle_error(cl: &SocketAddr, code: u16, msg: &str) -> Result<(), io::Error> { +fn send_error(cl: &SocketAddr, code: u16, msg: &str) -> Result<(), io::Error> { let socket = UdpSocket::bind("0.0.0.0:0")?; socket.connect(cl)?; @@ -182,15 +290,25 @@ fn handle_error(cl: &SocketAddr, code: u16, msg: &str) -> Result<(), io::Error> Ok(()) } -fn handle_client(cl: &SocketAddr, buf: &[u8]) -> Result<(), io::Error> { - let opcode = u16::from_be_bytes([buf[0], buf[1]]); +fn send_ack(sock: &UdpSocket, block_nr: u16) -> Result<(), io::Error> { + let mut buf = vec![0x00, 0x04]; // opcode + buf.extend(block_nr.to_be_bytes().iter()); + + sock.send(&buf)?; - match opcode { + Ok(()) +} + +fn handle_client(cl: &SocketAddr, buf: &[u8]) -> Result<(), io::Error> { + let _opcode = match u16::from_be_bytes([buf[0], buf[1]]) { 1 /* RRQ */ => handle_rrq(&cl, &buf[2..])?, 2 /* WRQ */ => handle_wrq(&cl, &buf[2..])?, 5 /* ERROR */ => println!("Received ERROR from {}", cl), - _ => handle_error(cl, 4, "Unexpected opcode")?, - } + _ => { + send_error(cl, 4, "Unexpected opcode")?; + return Err(io::Error::new(io::ErrorKind::Other, "unexpected opcode")); + } + }; Ok(()) } |
