bidirectional header passing, adding question struct

This commit is contained in:
Andy Pack 2024-01-29 18:37:25 +00:00
parent 4051665b6d
commit 985adbae68
Signed by: sarsoo
GPG Key ID: A55BA3536A5E0ED7
10 changed files with 407 additions and 27 deletions

View File

@ -27,3 +27,8 @@ jobs:
uses: actions-rs/cargo@v1
with:
command: build
- name: Cargo Test
uses: actions-rs/cargo@v1
with:
command: test

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
/target
**/*.log
.idea

67
Cargo.lock generated
View File

@ -130,6 +130,16 @@ name = "dnstplib"
version = "0.1.0"
dependencies = [
"log",
"url",
]
[[package]]
name = "form_urlencoded"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456"
dependencies = [
"percent-encoding",
]
[[package]]
@ -138,6 +148,16 @@ version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]]
name = "idna"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6"
dependencies = [
"unicode-bidi",
"unicode-normalization",
]
[[package]]
name = "itoa"
version = "1.0.10"
@ -165,6 +185,12 @@ dependencies = [
"libc",
]
[[package]]
name = "percent-encoding"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]]
name = "powerfmt"
version = "0.2.0"
@ -277,12 +303,53 @@ dependencies = [
"time-core",
]
[[package]]
name = "tinyvec"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "unicode-bidi"
version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75"
[[package]]
name = "unicode-ident"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
[[package]]
name = "unicode-normalization"
version = "0.1.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c5713f0fc4b5db668a2ac63cdb7bb4469d8c9fed047b1d0292cc7b0ce2ba921"
dependencies = [
"tinyvec",
]
[[package]]
name = "url"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633"
dependencies = [
"form_urlencoded",
"idna",
"percent-encoding",
]
[[package]]
name = "utf8parse"
version = "0.2.1"

View File

@ -1,19 +1,20 @@
use std::fs::File;
use std::net::{SocketAddr, UdpSocket};
use std::net::SocketAddr;
use std::thread;
use std::time::Duration;
use clap::Parser;
use log::{error, info, LevelFilter};
use log::{info, LevelFilter};
use simplelog::*;
use dnstplib::dns_socket::DNSSocket;
use dnstplib::raw_request::NetworkMessage;
use dnstplib::request_processor::RequestProcesor;
use dnstplib::response_processor::ResponseProcesor;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Addresses to send requests
#[arg(short, long)]
address: String,
}
fn main() {
@ -51,7 +52,7 @@ fn main() {
tx_channel.send(Box::from(NetworkMessage {
buffer: Box::from(send_buf),
peer: "127.0.0.1:5000".parse().unwrap()
peer: args.address.parse().unwrap()
}));
thread::sleep(Duration::from_secs(1));

View File

@ -7,3 +7,4 @@ edition = "2021"
[dependencies]
log = "0.4.20"
url = "2.5.0"

View File

@ -1,7 +1,14 @@
use std::convert::TryFrom;
pub const HEADER_SIZE: usize = 12;
#[derive(Ord, PartialOrd, Eq, PartialEq, Debug, Copy, Clone)]
pub enum Direction {
Request, Response
Request = 0,
Response = 1
}
#[derive(Ord, PartialOrd, Eq, PartialEq, Debug, Copy, Clone)]
pub enum Opcode {
Query = 0,
RQuery = 1,
@ -9,15 +16,57 @@ pub enum Opcode {
Reserved = 3
}
pub enum ResponseCode {
NoError = 0,
FormatSpecError = 1,
ServerFailure = 2,
NameError = 3,
RequestTypeUnsupported = 4,
NotExecuted = 5
impl TryFrom<u16> for Opcode {
type Error = ();
fn try_from(v: u16) -> Result<Self, Self::Error> {
match v {
x if x == Opcode::Query as u16 => Ok(Opcode::Query),
x if x == Opcode::RQuery as u16 => Ok(Opcode::RQuery),
x if x == Opcode::Status as u16 => Ok(Opcode::Status),
x if x == Opcode::Reserved as u16 => Ok(Opcode::Reserved),
_ => Err(()),
}
}
}
#[derive(Ord, PartialOrd, Eq, PartialEq, Debug, Copy, Clone)]
pub enum ResponseCode {
NoError = 0,
FormatError = 1,
ServerFailure = 2,
NameError = 3,
NotImplemented = 4,
Refused = 5,
YXDomain = 6,
YXRRSet = 7,
NXRRSet = 8,
NotAuth = 9,
NotZone = 10
}
impl TryFrom<u16> for ResponseCode {
type Error = ();
fn try_from(v: u16) -> Result<Self, Self::Error> {
match v {
x if x == ResponseCode::NoError as u16 => Ok(ResponseCode::NoError),
x if x == ResponseCode::FormatError as u16 => Ok(ResponseCode::FormatError),
x if x == ResponseCode::ServerFailure as u16 => Ok(ResponseCode::ServerFailure),
x if x == ResponseCode::NameError as u16 => Ok(ResponseCode::NameError),
x if x == ResponseCode::NotImplemented as u16 => Ok(ResponseCode::NotImplemented),
x if x == ResponseCode::Refused as u16 => Ok(ResponseCode::Refused),
x if x == ResponseCode::YXDomain as u16 => Ok(ResponseCode::YXDomain),
x if x == ResponseCode::YXRRSet as u16 => Ok(ResponseCode::YXRRSet),
x if x == ResponseCode::NXRRSet as u16 => Ok(ResponseCode::NXRRSet),
x if x == ResponseCode::NotAuth as u16 => Ok(ResponseCode::NotAuth),
x if x == ResponseCode::NotZone as u16 => Ok(ResponseCode::NotZone),
_ => Err(()),
}
}
}
#[derive(Ord, PartialOrd, Eq, PartialEq, Debug)]
pub struct DNSHeader {
pub id: u16,
pub direction: Direction,
@ -26,6 +75,7 @@ pub struct DNSHeader {
pub truncation: bool,
pub recursion_desired: bool,
pub recursion_available: bool,
pub valid_zeroes: bool,
pub response: ResponseCode,
pub question_count: u16,
pub answer_record_count: u16,

66
dnstp/src/dns_question.rs Normal file
View File

@ -0,0 +1,66 @@
use url::form_urlencoded;
#[derive(Ord, PartialOrd, Eq, PartialEq, Debug, Copy, Clone)]
enum QType {
A = 1,
NS = 2,
CNAME = 5,
SOA = 6,
WKS = 11,
PTR = 12,
HINFO = 13,
MINFO = 14,
MX = 15,
TXT = 16,
RP = 17,
AAAA = 28,
SRV = 33
}
#[derive(Ord, PartialOrd, Eq, PartialEq, Debug, Copy, Clone)]
enum QClass {
Internet = 1,
Chaos = 3,
Hesiod = 4,
}
struct DNSQuestion {
qname: String,
qtype: QType,
qclass: QClass
}
impl DNSQuestion {
pub fn new(qname: String, qtype: QType, qclass: QClass) -> DNSQuestion
{
DNSQuestion {
qname,
qtype,
qclass
}
}
pub fn to_bytes(&self) -> Vec<u8>
{
let mut ret: Vec<u8> = vec!();
for part in self.qname.split(".")
{
let encoded_string: String = form_urlencoded::byte_serialize(part.as_bytes()).collect();
let count = encoded_string.len();
ret.push(count as u8);
ret.reserve(count);
for x in encoded_string.into_bytes() {
ret.push(x);
};
}
ret.push(0);
ret.push(self.qtype as u8);
ret.push(self.qclass as u8);
ret
}
}

View File

@ -1,11 +1,13 @@
use std::net::{SocketAddr, UdpSocket};
use std::ptr::read;
use std::thread;
use std::thread::{JoinHandle};
use log::{error, info};
use log::{debug, error, info};
use std::str;
use std::sync::mpsc;
use std::sync::mpsc::{Receiver, Sender, TryRecvError};
use crate::dns_header::HEADER_SIZE;
use crate::raw_request::{NetworkMessage, NetworkMessagePtr};
pub struct DNSSocket {
@ -80,16 +82,18 @@ impl DNSSocket {
let res = s.recv_from(&mut (*buf));
match res {
Ok((_, peer)) => {
Ok((read_count, peer)) => {
let res_str = str::from_utf8(&(*buf)).unwrap();
info!("received [{}] from [{}]", res_str, peer);
match message_sender.send(Box::new(NetworkMessage {
buffer: buf,
peer
}))
{
Ok(_) => {}
Err(_) => {}
if read_count > HEADER_SIZE {
message_sender.send(Box::new(NetworkMessage {
buffer: buf,
peer
}));
}
else {
debug!("skipping processing message from [{}], message isn't longer than standard header", peer);
}
}
Err(_) => {}
@ -129,7 +133,9 @@ impl DNSSocket {
for m in &msg_rx {
info!("sending [{}] to [{}]", str::from_utf8(&(*(*m).buffer)).unwrap(), (*m).peer);
s.send_to(&(*m.buffer), m.peer);
if let Err(e) = s.send_to(&(*m.buffer), m.peer){
error!("error sending response to [{}], {}", m.peer, e);
}
}
cancelled = match rx.try_recv() {

View File

@ -4,3 +4,4 @@ mod dns_header;
pub mod request_processor;
pub mod response_processor;
pub mod raw_request;
mod dns_question;

View File

@ -0,0 +1,182 @@
use crate::dns_header::{Direction, DNSHeader, Opcode, ResponseCode};
use crate::dns_header::Direction::Response;
fn two_byte_extraction(buffer: &[u8], idx: usize) -> u16
{
((buffer[idx] as u16) << 8) | buffer[idx + 1] as u16
}
fn two_byte_split(num: u16) -> (u8, u8)
{
((num >> 8) as u8, (num & 0b0000000011111111) as u8)
}
const ID_START: usize = 0;
const FLAGS_START: usize = 2;
const DIRECTION_SHIFT: usize = 15;
const OPCODE_SHIFT: usize = 11;
const AUTHORITATIVE_SHIFT: usize = 10;
const TRUNCATION_SHIFT: usize = 9;
const RECURSION_DESIRED_SHIFT: usize = 8;
const RECURSION_AVAILABLE_SHIFT: usize = 7;
const ZEROES_SHIFT: usize = 4;
const QUESTION_COUNT_START: usize = 4;
const ANSWER_RECORD_COUNT_START: usize = 6;
const AUTHORITY_RECORD_COUNT_START: usize = 8;
const ADDITIONAL_RECORD_COUNT_START: usize = 10;
pub fn parse_header(header: &[u8; 12]) -> Result<DNSHeader, ()>
{
let id = two_byte_extraction(header, ID_START);
let flags = two_byte_extraction(header, FLAGS_START);
let direction = if flags & (0b1 << DIRECTION_SHIFT) == 0 {Direction::Request} else { Direction::Response };
let opcode: Result<Opcode, ()> = ((flags & (0b1111 << OPCODE_SHIFT)) >> OPCODE_SHIFT).try_into();
if let Err(e) = opcode {
return Err(e);
}
let authoritative = (flags & (0b1 << AUTHORITATIVE_SHIFT)) != 0;
let truncation = (flags & (0b1 << TRUNCATION_SHIFT)) != 0;
let recursion_desired = (flags & (0b1 << RECURSION_DESIRED_SHIFT)) != 0;
let recursion_available = (flags & (0b1 << RECURSION_AVAILABLE_SHIFT)) != 0;
let zeroes = (flags & (0b111 << ZEROES_SHIFT)) == 0;
let response: Result<ResponseCode, ()> = (flags & 0b1111).try_into();
if let Err(e) = response
{
return Err(e);
}
let question_count = two_byte_extraction(header, QUESTION_COUNT_START);
let answer_record_count = two_byte_extraction(header, ANSWER_RECORD_COUNT_START);
let authority_record_count = two_byte_extraction(header, AUTHORITY_RECORD_COUNT_START);
let additional_record_count = two_byte_extraction(header, ADDITIONAL_RECORD_COUNT_START);
Ok(DNSHeader {
id,
direction,
opcode: opcode.unwrap(),
authoritative,
truncation,
recursion_desired,
recursion_available,
valid_zeroes: zeroes,
response: response.unwrap(),
question_count,
answer_record_count,
authority_record_count,
additional_record_count
})
}
fn apply_split_bytes(buffer: &mut [u8], value: u16, index: usize)
{
let val = two_byte_split(value);
buffer[index] = val.0;
buffer[index + 1] = val.1;
}
pub fn parse_header_to_bytes(header: &DNSHeader) -> [u8; 12]
{
let mut header_bytes: [u8; 12] = [0; 12];
apply_split_bytes(&mut header_bytes, header.id, ID_START);
let mut flags: u16 = 0;
if header.direction == Response {
flags |= 0b1 << DIRECTION_SHIFT;
}
flags |= (header.opcode as u16) << OPCODE_SHIFT;
flags |= (header.authoritative as u16) << AUTHORITATIVE_SHIFT;
flags |= (header.truncation as u16) << TRUNCATION_SHIFT;
flags |= (header.recursion_desired as u16) << RECURSION_DESIRED_SHIFT;
flags |= (header.recursion_available as u16) << RECURSION_AVAILABLE_SHIFT;
flags |= header.response as u16;
apply_split_bytes(&mut header_bytes, flags, FLAGS_START);
apply_split_bytes(&mut header_bytes, header.question_count, QUESTION_COUNT_START);
apply_split_bytes(&mut header_bytes, header.answer_record_count, ANSWER_RECORD_COUNT_START);
apply_split_bytes(&mut header_bytes, header.authority_record_count, AUTHORITY_RECORD_COUNT_START);
apply_split_bytes(&mut header_bytes, header.additional_record_count, ADDITIONAL_RECORD_COUNT_START);
header_bytes
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn two_byte_extraction_test() {
let buffer: [u8; 12] = core::array::from_fn(|i| (i + 1) as u8);
let value = two_byte_extraction(&buffer, 0);
assert_eq!(value, 258);
let value = two_byte_extraction(&buffer, 2);
assert_eq!(value, 772);
}
#[test]
fn two_byte_split_test() {
let (val1, val2) = two_byte_split(258);
assert_eq!(val1, 1);
assert_eq!(val2, 2);
let (val1, val2) = two_byte_split(772);
assert_eq!(val1, 3);
assert_eq!(val2, 4);
}
#[test]
fn both_ways_test() {
let header = DNSHeader {
id: 100,
direction: Direction::Response,
opcode: Opcode::Query,
authoritative: true,
truncation: false,
recursion_desired: true,
recursion_available: false,
valid_zeroes: true,
response: ResponseCode::NoError,
question_count: 1,
answer_record_count: 2,
authority_record_count: 3,
additional_record_count: 4
};
let parsed_bytes = parse_header_to_bytes(&header);
let header_again = parse_header(&parsed_bytes).unwrap();
assert_eq!(header.id, header_again.id);
assert_eq!(header.direction, header_again.direction);
assert_eq!(header.opcode, header_again.opcode);
assert_eq!(header.authoritative, header_again.authoritative);
assert_eq!(header.truncation, header_again.truncation);
assert_eq!(header.recursion_desired, header_again.recursion_desired);
assert_eq!(header.recursion_available, header_again.recursion_available);
assert_eq!(header.valid_zeroes, header_again.valid_zeroes);
assert_eq!(header.response, header_again.response);
assert_eq!(header.question_count, header_again.question_count);
assert_eq!(header.answer_record_count, header_again.answer_record_count);
assert_eq!(header.authority_record_count, header_again.authority_record_count);
assert_eq!(header.additional_record_count, header_again.additional_record_count);
}
}