aboutsummaryrefslogtreecommitdiff
path: root/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs201
1 files changed, 201 insertions, 0 deletions
diff --git a/src/main.rs b/src/main.rs
new file mode 100644
index 0000000..c1a79d0
--- /dev/null
+++ b/src/main.rs
@@ -0,0 +1,201 @@
+use chrono::offset::Local;
+use futures_util::future::try_join;
+use http::uri::Uri;
+use hyper::service::{make_service_fn, service_fn};
+use hyper::upgrade::Upgraded;
+use hyper::{Body, Client, Method, Request, Response, Server};
+use regex::Regex;
+use serde::Deserialize;
+use serde_dhall::StaticType;
+use std::convert::Infallible;
+use std::net::SocketAddr;
+use structopt::StructOpt;
+use tokio::net::TcpStream;
+
+#[macro_use]
+extern crate log;
+
+type HttpClient = Client<hyper::client::HttpConnector>;
+
+#[derive(Debug, StructOpt)]
+#[structopt()]
+struct Opt {
+ #[structopt(short, long)]
+ config: String,
+}
+
+#[derive(Debug, Deserialize, StaticType, Clone)]
+struct Access {
+ from: String,
+ to: String,
+ hosts: Vec<String>,
+}
+
+#[tokio::main]
+async fn main() {
+ env_logger::from_env(env_logger::Env::default().default_filter_or("info")).init();
+
+ let opt = Opt::from_args();
+
+ match serde_dhall::from_file(opt.config)
+ .static_type_annotation()
+ .parse::<Vec<Access>>()
+ {
+ Ok(config) => {
+ let addr = SocketAddr::from(([127, 0, 0, 1], 8100));
+ let client = HttpClient::new();
+
+ let make_service = make_service_fn(move |_| {
+ let config = config.clone();
+ let client = client.clone();
+ async move {
+ Ok::<_, Infallible>(service_fn(move |req| {
+ proxy(config.clone(), client.clone(), req)
+ }))
+ }
+ });
+
+ let server = Server::bind(&addr).serve(make_service);
+
+ info!("Listening on http://{}", addr);
+
+ if let Err(e) = server.await {
+ error!("server error: {}", e);
+ }
+ }
+ Err(err) => error!("{}", err),
+ }
+}
+
+async fn proxy(
+ config: Vec<Access>,
+ client: HttpClient,
+ req: Request<Body>,
+) -> Result<Response<Body>, hyper::Error> {
+ debug!("req: {:?}", req);
+ let uri = req.uri();
+ let unauthorized_hosts = currently_unauthorized_hosts(config);
+
+ if block_uri(&unauthorized_hosts, uri) {
+ info!("Blocked: {:?}", uri);
+ Ok(Response::new(Body::empty()))
+ } else {
+ info!("Authorized: {:?}", uri);
+
+ if req.method() == Method::CONNECT {
+ if let Some(addr) = host_addr(uri) {
+ tokio::task::spawn(async move {
+ match req.into_body().on_upgrade().await {
+ Ok(upgraded) => {
+ if let Err(e) = tunnel(upgraded, &addr).await {
+ warn!("server io error: {}", e);
+ };
+ }
+ Err(e) => error!("upgrade error: {}", e),
+ }
+ });
+
+ Ok(Response::new(Body::empty()))
+ } else {
+ error!("CONNECT host is not socket addr: {:?}", uri);
+ let mut resp = Response::new(Body::from("CONNECT must be to a socket address"));
+ *resp.status_mut() = http::StatusCode::BAD_REQUEST;
+
+ Ok(resp)
+ }
+ } else {
+ client.request(req).await
+ }
+ }
+}
+
+fn host_addr(uri: &http::Uri) -> Option<String> {
+ uri.authority().map(|auth| auth.as_str().to_string())
+}
+
+// Create a TCP connection to host:port, build a tunnel between the connection and
+// the upgraded connection
+async fn tunnel(upgraded: Upgraded, addr: &str) -> std::io::Result<()> {
+ // Connect to remote server
+ let mut server = TcpStream::connect(addr).await?;
+
+ // Proxying data
+ let amounts = {
+ let (mut server_rd, mut server_wr) = server.split();
+ let (mut client_rd, mut client_wr) = tokio::io::split(upgraded);
+
+ let client_to_server = tokio::io::copy(&mut client_rd, &mut server_wr);
+ let server_to_client = tokio::io::copy(&mut server_rd, &mut client_wr);
+
+ try_join(client_to_server, server_to_client).await
+ };
+
+ match amounts {
+ Ok((from_client, from_server)) => {
+ debug!(
+ "client wrote {} bytes and received {} bytes",
+ from_client, from_server
+ );
+ }
+ Err(e) => {
+ warn!("tunnel error: {}", e);
+ }
+ };
+ Ok(())
+}
+
+fn block_uri(unauthorized_hosts: &Vec<String>, uri: &Uri) -> bool {
+ match uri.host() {
+ Some(h) => unauthorized_hosts
+ .iter()
+ .find(|unauthorized_host| h.contains(*unauthorized_host))
+ .is_some(),
+ None => false,
+ }
+}
+
+fn currently_unauthorized_hosts(config: Vec<Access>) -> Vec<String> {
+ let now = Local::now();
+ let hour = now.format("%H").to_string().parse::<i32>().unwrap();
+ let minutes = now.format("%M").to_string().parse::<i32>().unwrap();
+ let now = (hour, minutes);
+
+ config
+ .into_iter()
+ .map(|access| {
+ let from = parse_time(&access.from);
+ let to = parse_time(&access.to);
+
+ if is_before_or_eq(from, to) {
+ if is_after_or_eq(now, from) && is_before_or_eq(now, to) {
+ access.hosts.clone()
+ } else {
+ Vec::new()
+ }
+ } else {
+ if is_after_or_eq(now, from) || is_before_or_eq(now, to) {
+ access.hosts.clone()
+ } else {
+ Vec::new()
+ }
+ }
+ })
+ .collect::<Vec<Vec<String>>>()
+ .concat()
+}
+
+fn parse_time(str: &str) -> (i32, i32) {
+ let regex = Regex::new(r"(\d{2}):(\d{2})").unwrap();
+ match regex.captures(str) {
+ Some(cap) => (cap[1].parse().unwrap(), cap[2].parse().unwrap()),
+ None => panic!("Error while parsing time: {}", str),
+ }
+}
+
+fn is_before_or_eq(t1: (i32, i32), t2: (i32, i32)) -> bool {
+ t1.0 < t2.0 || t1.0 == t2.0 && t1.1 <= t2.1
+}
+
+fn is_after_or_eq(t1: (i32, i32), t2: (i32, i32)) -> bool {
+ is_before_or_eq(t2, t1)
+}