scufflecloud_core/middleware/
auth.rs

use std::str::FromStr;
use std::sync::Arc;

use axum::extract::Request;
use axum::http::{HeaderMap, HeaderName, StatusCode};
use axum::middleware::Next;
use axum::response::Response;
use base64::Engine;
use diesel::{BoolExpressionMethods, ExpressionMethods, QueryDsl, SelectableHelper};
use diesel_async::RunQueryDsl;
use fred::prelude::KeysInterface;
use hmac::Mac;

use crate::CoreConfig;
use crate::http_ext::RequestExt;
use crate::middleware::IpAddressInfo;
use crate::models::{UserSession, UserSessionTokenId};
use crate::schema::user_sessions;
19
20const TOKEN_ID_HEADER: HeaderName = HeaderName::from_static("scuf-token-id");
21const TIMESTAMP_HEADER: HeaderName = HeaderName::from_static("scuf-timestamp");
22const NONCE_HEADER: HeaderName = HeaderName::from_static("scuf-nonce");
23
24const AUTHENTICATION_METHOD_HEADER: HeaderName = HeaderName::from_static("scuf-auth-method");
25const AUTHENTICATION_HMAC_HEADER: HeaderName = HeaderName::from_static("scuf-auth-hmac");
26
27#[derive(Clone, Debug)]
28pub(crate) struct ExpiredSession(pub UserSession);
29
30pub(crate) async fn auth<G: CoreConfig>(mut req: Request, next: Next) -> Result<Response, StatusCode> {
31    let global = req
32        .extensions()
33        .global::<G>()
34        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
35    let ip_info = req
36        .extensions()
37        .ip_address_info()
38        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
39
40    let (session, expired_session) = get_and_update_active_session(&global, &ip_info, req.headers()).await?;
41    if let Some(session) = session {
42        req.extensions_mut().insert(session);
43    }
44    if let Some(expired_session) = expired_session {
45        req.extensions_mut().insert(expired_session);
46    }
47
48    Ok(next.run(req).await)
49}
50
51fn get_auth_header<'a, T>(headers: &'a HeaderMap, header_name: &HeaderName) -> Result<Option<T>, StatusCode>
52where
53    T: FromStr + 'a,
54    T::Err: std::fmt::Display,
55{
56    match headers.get(header_name) {
57        Some(h) => {
58            let s = h.to_str().map_err(|e| {
59                tracing::debug!(header = %header_name, error = %e, "invalid header value");
60                StatusCode::BAD_REQUEST
61            })?;
62            Ok(Some(s.parse().map_err(|e| {
63                tracing::debug!(header = %header_name, error = %e, "failed to parse header value");
64                StatusCode::BAD_REQUEST
65            })?))
66        }
67        None => Ok(None),
68    }
69}
70
71#[derive(Debug, thiserror::Error)]
72enum AuthenticationMethodParseError {
73    #[error("unknown authentication algorithm")]
74    UnknownAlgorithm,
75    #[error("invalid header format")]
76    InvalidHeaderFormat,
77}
78
79#[derive(Debug)]
80enum AuthenticationAlgorithm {
81    HmacSha256,
82}
83
84impl FromStr for AuthenticationAlgorithm {
85    type Err = AuthenticationMethodParseError;
86
87    fn from_str(s: &str) -> Result<Self, Self::Err> {
88        match s {
89            "HMAC-SHA256" => Ok(AuthenticationAlgorithm::HmacSha256),
90            _ => Err(AuthenticationMethodParseError::UnknownAlgorithm),
91        }
92    }
93}
94
95#[derive(Debug)]
96struct AuthenticationMethod {
97    pub algorithm: AuthenticationAlgorithm,
98    pub headers: Vec<HeaderName>,
99}
100
101impl FromStr for AuthenticationMethod {
102    type Err = AuthenticationMethodParseError;
103
104    fn from_str(s: &str) -> Result<Self, Self::Err> {
105        let parts: Vec<&str> = s.splitn(2, ';').collect();
106        if parts.len() != 2 {
107            return Err(AuthenticationMethodParseError::InvalidHeaderFormat);
108        }
109
110        let algorithm: AuthenticationAlgorithm = parts[0].parse()?;
111        let headers: Vec<HeaderName> = parts[1]
112            .split(',')
113            .map(|h| HeaderName::from_str(h.trim()).map_err(|_| AuthenticationMethodParseError::InvalidHeaderFormat))
114            .collect::<Result<_, _>>()?;
115
116        Ok(AuthenticationMethod { algorithm, headers })
117    }
118}
119
120#[derive(thiserror::Error, Debug)]
121enum NonceParseError {
122    #[error("failed to decode: {0}")]
123    Base64(#[from] base64::DecodeError),
124    #[error("invalid nonce length {0}, must be 32 bytes")]
125    InvalidLength(usize),
126}
127
128#[derive(Debug)]
129struct Nonce(Vec<u8>);
130
131impl FromStr for Nonce {
132    type Err = NonceParseError;
133
134    fn from_str(s: &str) -> Result<Self, Self::Err> {
135        let bytes = base64::prelude::BASE64_STANDARD.decode(s)?;
136        if bytes.len() != 32 {
137            return Err(NonceParseError::InvalidLength(bytes.len()));
138        }
139        Ok(Nonce(bytes))
140    }
141}
142
143#[derive(Debug)]
144struct AuthenticationHmac(Vec<u8>);
145
146impl FromStr for AuthenticationHmac {
147    type Err = base64::DecodeError;
148
149    fn from_str(s: &str) -> Result<Self, Self::Err> {
150        let bytes = base64::prelude::BASE64_STANDARD.decode(s)?;
151        Ok(AuthenticationHmac(bytes))
152    }
153}
154
155async fn get_and_update_active_session<G: CoreConfig>(
156    global: &Arc<G>,
157    ip_info: &IpAddressInfo,
158    headers: &HeaderMap,
159) -> Result<(Option<UserSession>, Option<ExpiredSession>), StatusCode> {
160    let Some(session_token_id) = get_auth_header::<UserSessionTokenId>(headers, &TOKEN_ID_HEADER)? else {
161        return Ok((None, None));
162    };
163    let Some(timestamp) =
164        get_auth_header::<u64>(headers, &TIMESTAMP_HEADER)?.and_then(|t| chrono::DateTime::from_timestamp_millis(t as i64))
165    else {
166        return Ok((None, None));
167    };
168    let Some(nonce) = get_auth_header::<Nonce>(headers, &NONCE_HEADER)? else {
169        return Ok((None, None));
170    };
171
172    let Some(auth_method) = get_auth_header::<AuthenticationMethod>(headers, &AUTHENTICATION_METHOD_HEADER)? else {
173        return Ok((None, None));
174    };
175    let Some(auth_hmac) = get_auth_header::<AuthenticationHmac>(headers, &AUTHENTICATION_HMAC_HEADER)? else {
176        return Ok((None, None));
177    };
178
179    if timestamp > chrono::Utc::now() || timestamp < chrono::Utc::now() - global.max_request_lifetime() {
180        tracing::debug!(timestamp = %timestamp, "invalid request timestamp");
181        return Err(StatusCode::UNAUTHORIZED);
182    }
183
184    if !auth_method.headers.contains(&TOKEN_ID_HEADER)
185        || !auth_method.headers.contains(&TIMESTAMP_HEADER)
186        || !auth_method.headers.contains(&NONCE_HEADER)
187    {
188        tracing::debug!("missing required headers in authentication method");
189        return Err(StatusCode::BAD_REQUEST);
190    }
191
192    let mut db = global.db().await.map_err(|e| {
193        tracing::error!(error = %e, "failed to connect to database");
194        StatusCode::INTERNAL_SERVER_ERROR
195    })?;
196
197    let Some(session) = diesel::update(user_sessions::dsl::user_sessions)
198        .set((
199            user_sessions::dsl::last_ip.eq(ip_info.to_network()),
200            user_sessions::dsl::last_used_at.eq(chrono::Utc::now()),
201        ))
202        .filter(
203            user_sessions::dsl::token_id
204                .eq(session_token_id)
205                .and(user_sessions::dsl::token.is_not_null())
206                .and(user_sessions::dsl::expires_at.gt(chrono::Utc::now())),
207        )
208        .returning(UserSession::as_select())
209        .get_results::<UserSession>(&mut db)
210        .await
211        .map_err(|e| {
212            tracing::error!(error = %e, "failed to update user session");
213            StatusCode::INTERNAL_SERVER_ERROR
214        })?
215        .into_iter()
216        .next()
217    else {
218        tracing::debug!(token_id = %session_token_id, "no active session found");
219        return Err(StatusCode::UNAUTHORIZED);
220    };
221
222    let token = session.token.as_ref().expect("known to be not null due to filter");
223
224    // Verify HMAC
225    match auth_method.algorithm {
226        AuthenticationAlgorithm::HmacSha256 => {
227            let mut mac = hmac::Hmac::<sha2::Sha256>::new_from_slice(token).map_err(|e| {
228                tracing::error!(error = %e, "failed to create HMAC instance");
229                StatusCode::INTERNAL_SERVER_ERROR
230            })?;
231
232            for header_name in &auth_method.headers {
233                if let Some(value) = headers.get(header_name) {
234                    mac.update(value.as_bytes());
235                } else {
236                    tracing::debug!(header = %header_name, "missing header");
237                    return Err(StatusCode::BAD_REQUEST);
238                }
239            }
240
241            mac.verify_slice(&auth_hmac.0).map_err(|e| {
242                tracing::debug!(error = %e, "HMAC verification failed");
243                StatusCode::UNAUTHORIZED
244            })?;
245        }
246    }
247
248    let mut key = "nonces:".as_bytes().to_vec();
249    key.extend_from_slice(&nonce.0);
250    let value: Option<bool> = global
251        .redis()
252        .set(
253            key.as_slice(),
254            true,
255            Some(fred::types::Expiration::PX(global.max_request_lifetime().num_milliseconds())),
256            Some(fred::types::SetOptions::NX),
257            true,
258        )
259        .await
260        .map_err(|e| {
261            tracing::error!(error = %e, "failed to set nonce in redis");
262            StatusCode::INTERNAL_SERVER_ERROR
263        })?;
264
265    if value.is_some() {
266        tracing::debug!("replayed nonce detected");
267        return Err(StatusCode::UNAUTHORIZED);
268    }
269
270    if session.token_expires_at.is_some_and(|t| t <= chrono::Utc::now()) {
271        return Ok((None, Some(ExpiredSession(session))));
272    }
273
274    Ok((Some(session), None))
275}