diff options
Diffstat (limited to 'src/main.rs')
-rw-r--r-- | src/main.rs | 201 |
1 files changed, 201 insertions, 0 deletions
diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..c1a79d0 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,201 @@ +use chrono::offset::Local; +use futures_util::future::try_join; +use http::uri::Uri; +use hyper::service::{make_service_fn, service_fn}; +use hyper::upgrade::Upgraded; +use hyper::{Body, Client, Method, Request, Response, Server}; +use regex::Regex; +use serde::Deserialize; +use serde_dhall::StaticType; +use std::convert::Infallible; +use std::net::SocketAddr; +use structopt::StructOpt; +use tokio::net::TcpStream; + +#[macro_use] +extern crate log; + +type HttpClient = Client<hyper::client::HttpConnector>; + +#[derive(Debug, StructOpt)] +#[structopt()] +struct Opt { + #[structopt(short, long)] + config: String, +} + +#[derive(Debug, Deserialize, StaticType, Clone)] +struct Access { + from: String, + to: String, + hosts: Vec<String>, +} + +#[tokio::main] +async fn main() { + env_logger::from_env(env_logger::Env::default().default_filter_or("info")).init(); + + let opt = Opt::from_args(); + + match serde_dhall::from_file(opt.config) + .static_type_annotation() + .parse::<Vec<Access>>() + { + Ok(config) => { + let addr = SocketAddr::from(([127, 0, 0, 1], 8100)); + let client = HttpClient::new(); + + let make_service = make_service_fn(move |_| { + let config = config.clone(); + let client = client.clone(); + async move { + Ok::<_, Infallible>(service_fn(move |req| { + proxy(config.clone(), client.clone(), req) + })) + } + }); + + let server = Server::bind(&addr).serve(make_service); + + info!("Listening on http://{}", addr); + + if let Err(e) = server.await { + error!("server error: {}", e); + } + } + Err(err) => error!("{}", err), + } +} + +async fn proxy( + config: Vec<Access>, + client: HttpClient, + req: Request<Body>, +) -> Result<Response<Body>, hyper::Error> { + debug!("req: {:?}", req); + let uri = req.uri(); + let unauthorized_hosts = currently_unauthorized_hosts(config); + + if block_uri(&unauthorized_hosts, uri) { + info!("Blocked: {:?}", uri); + Ok(Response::new(Body::empty())) + } else { + info!("Authorized: {:?}", uri); + + if req.method() == Method::CONNECT { + if let Some(addr) = host_addr(uri) { + tokio::task::spawn(async move { + match req.into_body().on_upgrade().await { + Ok(upgraded) => { + if let Err(e) = tunnel(upgraded, &addr).await { + warn!("server io error: {}", e); + }; + } + Err(e) => error!("upgrade error: {}", e), + } + }); + + Ok(Response::new(Body::empty())) + } else { + error!("CONNECT host is not socket addr: {:?}", uri); + let mut resp = Response::new(Body::from("CONNECT must be to a socket address")); + *resp.status_mut() = http::StatusCode::BAD_REQUEST; + + Ok(resp) + } + } else { + client.request(req).await + } + } +} + +fn host_addr(uri: &http::Uri) -> Option<String> { + uri.authority().map(|auth| auth.as_str().to_string()) +} + +// Create a TCP connection to host:port, build a tunnel between the connection and +// the upgraded connection +async fn tunnel(upgraded: Upgraded, addr: &str) -> std::io::Result<()> { + // Connect to remote server + let mut server = TcpStream::connect(addr).await?; + + // Proxying data + let amounts = { + let (mut server_rd, mut server_wr) = server.split(); + let (mut client_rd, mut client_wr) = tokio::io::split(upgraded); + + let client_to_server = tokio::io::copy(&mut client_rd, &mut server_wr); + let server_to_client = tokio::io::copy(&mut server_rd, &mut client_wr); + + try_join(client_to_server, server_to_client).await + }; + + match amounts { + Ok((from_client, from_server)) => { + debug!( + "client wrote {} bytes and received {} bytes", + from_client, from_server + ); + } + Err(e) => { + warn!("tunnel error: {}", e); + } + }; + Ok(()) +} + +fn block_uri(unauthorized_hosts: &Vec<String>, uri: &Uri) -> bool { + match uri.host() { + Some(h) => unauthorized_hosts + .iter() + .find(|unauthorized_host| h.contains(*unauthorized_host)) + .is_some(), + None => false, + } +} + +fn currently_unauthorized_hosts(config: Vec<Access>) -> Vec<String> { + let now = Local::now(); + let hour = now.format("%H").to_string().parse::<i32>().unwrap(); + let minutes = now.format("%M").to_string().parse::<i32>().unwrap(); + let now = (hour, minutes); + + config + .into_iter() + .map(|access| { + let from = parse_time(&access.from); + let to = parse_time(&access.to); + + if is_before_or_eq(from, to) { + if is_after_or_eq(now, from) && is_before_or_eq(now, to) { + access.hosts.clone() + } else { + Vec::new() + } + } else { + if is_after_or_eq(now, from) || is_before_or_eq(now, to) { + access.hosts.clone() + } else { + Vec::new() + } + } + }) + .collect::<Vec<Vec<String>>>() + .concat() +} + +fn parse_time(str: &str) -> (i32, i32) { + let regex = Regex::new(r"(\d{2}):(\d{2})").unwrap(); + match regex.captures(str) { + Some(cap) => (cap[1].parse().unwrap(), cap[2].parse().unwrap()), + None => panic!("Error while parsing time: {}", str), + } +} + +fn is_before_or_eq(t1: (i32, i32), t2: (i32, i32)) -> bool { + t1.0 < t2.0 || t1.0 == t2.0 && t1.1 <= t2.1 +} + +fn is_after_or_eq(t1: (i32, i32), t2: (i32, i32)) -> bool { + is_before_or_eq(t2, t1) +} |