diff options
| author | mo khan <mo@mokhan.ca> | 2025-06-09 16:43:16 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-06-09 16:43:16 -0600 |
| commit | 2ef774d4c52b9fb0ae0d1717b7a3568b76bccf3d (patch) | |
| tree | fde8c20a9333e68d7e798ec5936630375da2a1f9 /src/oauth/server.rs | |
| parent | b39a50e3ec622294cc0b6f271f1996a89f1849d6 (diff) | |
refactor: split types into separate files
Diffstat (limited to 'src/oauth/server.rs')
| -rw-r--r-- | src/oauth/server.rs | 171 |
1 files changed, 171 insertions, 0 deletions
diff --git a/src/oauth/server.rs b/src/oauth/server.rs new file mode 100644 index 0000000..fdaddf6 --- /dev/null +++ b/src/oauth/server.rs @@ -0,0 +1,171 @@ +use crate::config::Config; +use crate::oauth::types::{AuthCode, Claims, ErrorResponse, TokenResponse}; +use jsonwebtoken::{DecodingKey, EncodingKey, Header, encode}; +use std::collections::HashMap; +use std::time::{SystemTime, UNIX_EPOCH}; +use url::Url; +use uuid::Uuid; + +pub struct OAuthServer { + config: Config, + encoding_key: EncodingKey, + decoding_key: DecodingKey, + auth_codes: std::sync::Mutex<HashMap<String, AuthCode>>, +} + +impl OAuthServer { + pub fn new(config: &Config) -> Self { + Self { + encoding_key: EncodingKey::from_secret(config.jwt_secret.as_ref()), + decoding_key: DecodingKey::from_secret(config.jwt_secret.as_ref()), + auth_codes: std::sync::Mutex::new(HashMap::new()), + config: config.clone(), + } + } + + pub fn get_jwks(&self) -> String { + serde_json::json!({ + "keys": [] + }) + .to_string() + } + + pub fn handle_authorize(&self, params: &HashMap<String, String>) -> Result<String, String> { + let client_id = params + .get("client_id") + .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; + + let redirect_uri = params + .get("redirect_uri") + .ok_or_else(|| self.error_response("invalid_request", "Missing redirect_uri"))?; + + let response_type = params + .get("response_type") + .ok_or_else(|| self.error_response("invalid_request", "Missing response_type"))?; + + if response_type != "code" { + return Err(self.error_response( + "unsupported_response_type", + "Only code response type supported", + )); + } + + let code = Uuid::new_v4().to_string(); + let expires_at = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + + 600; + + let auth_code = AuthCode { + client_id: client_id.clone(), + redirect_uri: redirect_uri.clone(), + scope: params.get("scope").cloned(), + expires_at, + user_id: "test_user".to_string(), + }; + + { + let mut codes = self.auth_codes.lock().unwrap(); + codes.insert(code.clone(), auth_code); + } + + let mut redirect_url = Url::parse(redirect_uri) + .map_err(|_| self.error_response("invalid_request", "Invalid redirect_uri"))?; + + redirect_url.query_pairs_mut().append_pair("code", &code); + + if let Some(state) = params.get("state") { + redirect_url.query_pairs_mut().append_pair("state", state); + } + + Ok(redirect_url.to_string()) + } + + pub fn handle_token(&self, params: &HashMap<String, String>) -> Result<String, String> { + let grant_type = params + .get("grant_type") + .ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?; + + if grant_type != "authorization_code" { + return Err(self.error_response( + "unsupported_grant_type", + "Only authorization_code grant type supported", + )); + } + + let code = params + .get("code") + .ok_or_else(|| self.error_response("invalid_request", "Missing code"))?; + + let client_id = params + .get("client_id") + .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; + + let auth_code = { + let mut codes = self.auth_codes.lock().unwrap(); + codes.remove(code).ok_or_else(|| { + self.error_response("invalid_grant", "Invalid or expired authorization code") + })? + }; + + if auth_code.client_id != *client_id { + return Err(self.error_response("invalid_grant", "Client ID mismatch")); + } + + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + if now > auth_code.expires_at { + return Err(self.error_response("invalid_grant", "Authorization code expired")); + } + + let access_token = + self.generate_access_token(&auth_code.user_id, client_id, &auth_code.scope)?; + + let token_response = TokenResponse { + access_token, + token_type: "Bearer".to_string(), + expires_in: 3600, + refresh_token: None, + scope: auth_code.scope, + }; + + serde_json::to_string(&token_response) + .map_err(|_| self.error_response("server_error", "Failed to serialize token response")) + } + + fn generate_access_token( + &self, + user_id: &str, + client_id: &str, + scope: &Option<String>, + ) -> Result<String, String> { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + let claims = Claims { + sub: user_id.to_string(), + iss: self.config.issuer_url.clone(), + aud: client_id.to_string(), + exp: now + 3600, + iat: now, + scope: scope.clone(), + }; + + encode(&Header::default(), &claims, &self.encoding_key) + .map_err(|_| self.error_response("server_error", "Failed to generate token")) + } + + fn error_response(&self, error: &str, description: &str) -> String { + let error_resp = ErrorResponse { + error: error.to_string(), + error_description: Some(description.to_string()), + }; + serde_json::to_string(&error_resp).unwrap_or_else(|_| "{}".to_string()) + } +}
\ No newline at end of file |
