diff --git a/Cargo.toml b/Cargo.toml index 0413c04..5465714 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ edition = "2018" [dependencies] actix-web = "1.0.*" actix-files = "*" +actix-service = "*" actix-identity = "*" lootalot-db = { version = "0.1", path = "./lootalot_db" } dotenv = "*" diff --git a/src/server.rs b/src/server.rs index 1463c34..6816a8c 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,11 +1,13 @@ use actix_cors::Cors; use actix_files as fs; -use actix_identity::{CookieIdentityPolicy, Identity, IdentityService}; +use actix_identity::{CookieIdentityPolicy, Identity, IdentityService, RequestIdentity}; use actix_web::{ + dev::{ServiceRequest, ServiceResponse}, http::{header, StatusCode}, - middleware, web, App, Either, Error, HttpResponse, HttpServer, + middleware, web, App, Error, HttpResponse, HttpServer, }; -use futures::Future; +use actix_service::{Service, Transform}; +use futures::{Future, future::{ok, Either, FutureResult}}; use std::env; use crate::api; @@ -38,9 +40,60 @@ fn db_call( fn restricted_to_group(id: i32, params: (AppPool, api::ApiActions)) -> MaybeForbidden { if id != 0 { - Either::B(HttpResponse::Forbidden().finish()) + actix_web::Either::B(HttpResponse::Forbidden().finish()) } else { - Either::A(Box::new(db_call(params.0, params.1))) + actix_web::Either::A(Box::new(db_call(params.0, params.1))) + } +} + + +struct RestrictedAccess; + +impl Transform for RestrictedAccess +where + S: Service, Error = Error>, + S::Future: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type InitError = (); + type Transform = RestrictedAccessMiddleware; + type Future = FutureResult; + + fn new_transform(&self, service: S) -> Self::Future { + ok(RestrictedAccessMiddleware { service }) + } +} + +struct RestrictedAccessMiddleware { + service: S +} + +impl Service for RestrictedAccessMiddleware +where + S: Service, Error = Error>, + S::Future: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type Future = Either>; + + fn poll_ready(&mut self) -> futures::Poll<(), Self::Error> { + self.service.poll_ready() + } + + fn call(&mut self, req: ServiceRequest) -> Self::Future { + let is_logged_in = req.get_identity().is_some(); + + if is_logged_in { + Either::A(self.service.call(req)) + } else { + Either::B(ok(req.into_response( + HttpResponse::Forbidden().finish().into_body() + ))) + } } } @@ -48,6 +101,7 @@ fn configure_api(config: &mut web::ServiceConfig) { use api::ApiActions as Q; config.service( web::scope("/api") + .wrap(RestrictedAccess) .service( web::scope("/players") .service(