diff options
Diffstat (limited to 'src/oauth')
| -rw-r--r-- | src/oauth/mod.rs | 5 | ||||
| -rw-r--r-- | src/oauth/server.rs | 171 | ||||
| -rw-r--r-- | src/oauth/types.rs | 39 |
3 files changed, 215 insertions, 0 deletions
diff --git a/src/oauth/mod.rs b/src/oauth/mod.rs new file mode 100644 index 0000000..d6717c2 --- /dev/null +++ b/src/oauth/mod.rs @@ -0,0 +1,5 @@ +pub mod server; +pub mod types; + +pub use server::OAuthServer; +pub use types::{AuthCode, Claims, ErrorResponse, TokenResponse};
\ No newline at end of file diff --git a/src/oauth/server.rs b/src/oauth/server.rs new file mode 100644 index 0000000..fdaddf6 --- /dev/null +++ b/src/oauth/server.rs @@ -0,0 +1,171 @@ +use crate::config::Config; +use crate::oauth::types::{AuthCode, Claims, ErrorResponse, TokenResponse}; +use jsonwebtoken::{DecodingKey, EncodingKey, Header, encode}; +use std::collections::HashMap; +use std::time::{SystemTime, UNIX_EPOCH}; +use url::Url; +use uuid::Uuid; + +pub struct OAuthServer { + config: Config, + encoding_key: EncodingKey, + decoding_key: DecodingKey, + auth_codes: std::sync::Mutex<HashMap<String, AuthCode>>, +} + +impl OAuthServer { + pub fn new(config: &Config) -> Self { + Self { + encoding_key: EncodingKey::from_secret(config.jwt_secret.as_ref()), + decoding_key: DecodingKey::from_secret(config.jwt_secret.as_ref()), + auth_codes: std::sync::Mutex::new(HashMap::new()), + config: config.clone(), + } + } + + pub fn get_jwks(&self) -> String { + serde_json::json!({ + "keys": [] + }) + .to_string() + } + + pub fn handle_authorize(&self, params: &HashMap<String, String>) -> Result<String, String> { + let client_id = params + .get("client_id") + .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; + + let redirect_uri = params + .get("redirect_uri") + .ok_or_else(|| self.error_response("invalid_request", "Missing redirect_uri"))?; + + let response_type = params + .get("response_type") + .ok_or_else(|| self.error_response("invalid_request", "Missing response_type"))?; + + if response_type != "code" { + return Err(self.error_response( + "unsupported_response_type", + "Only code response type supported", + )); + } + + let code = Uuid::new_v4().to_string(); + let expires_at = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + + 600; + + let auth_code = AuthCode { + client_id: client_id.clone(), + redirect_uri: redirect_uri.clone(), + scope: params.get("scope").cloned(), + expires_at, + user_id: "test_user".to_string(), + }; + + { + let mut codes = self.auth_codes.lock().unwrap(); + codes.insert(code.clone(), auth_code); + } + + let mut redirect_url = Url::parse(redirect_uri) + .map_err(|_| self.error_response("invalid_request", "Invalid redirect_uri"))?; + + redirect_url.query_pairs_mut().append_pair("code", &code); + + if let Some(state) = params.get("state") { + redirect_url.query_pairs_mut().append_pair("state", state); + } + + Ok(redirect_url.to_string()) + } + + pub fn handle_token(&self, params: &HashMap<String, String>) -> Result<String, String> { + let grant_type = params + .get("grant_type") + .ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?; + + if grant_type != "authorization_code" { + return Err(self.error_response( + "unsupported_grant_type", + "Only authorization_code grant type supported", + )); + } + + let code = params + .get("code") + .ok_or_else(|| self.error_response("invalid_request", "Missing code"))?; + + let client_id = params + .get("client_id") + .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; + + let auth_code = { + let mut codes = self.auth_codes.lock().unwrap(); + codes.remove(code).ok_or_else(|| { + self.error_response("invalid_grant", "Invalid or expired authorization code") + })? + }; + + if auth_code.client_id != *client_id { + return Err(self.error_response("invalid_grant", "Client ID mismatch")); + } + + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + if now > auth_code.expires_at { + return Err(self.error_response("invalid_grant", "Authorization code expired")); + } + + let access_token = + self.generate_access_token(&auth_code.user_id, client_id, &auth_code.scope)?; + + let token_response = TokenResponse { + access_token, + token_type: "Bearer".to_string(), + expires_in: 3600, + refresh_token: None, + scope: auth_code.scope, + }; + + serde_json::to_string(&token_response) + .map_err(|_| self.error_response("server_error", "Failed to serialize token response")) + } + + fn generate_access_token( + &self, + user_id: &str, + client_id: &str, + scope: &Option<String>, + ) -> Result<String, String> { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + let claims = Claims { + sub: user_id.to_string(), + iss: self.config.issuer_url.clone(), + aud: client_id.to_string(), + exp: now + 3600, + iat: now, + scope: scope.clone(), + }; + + encode(&Header::default(), &claims, &self.encoding_key) + .map_err(|_| self.error_response("server_error", "Failed to generate token")) + } + + fn error_response(&self, error: &str, description: &str) -> String { + let error_resp = ErrorResponse { + error: error.to_string(), + error_description: Some(description.to_string()), + }; + serde_json::to_string(&error_resp).unwrap_or_else(|_| "{}".to_string()) + } +}
\ No newline at end of file diff --git a/src/oauth/types.rs b/src/oauth/types.rs new file mode 100644 index 0000000..685af02 --- /dev/null +++ b/src/oauth/types.rs @@ -0,0 +1,39 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct Claims { + pub sub: String, + pub iss: String, + pub aud: String, + pub exp: u64, + pub iat: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub scope: Option<String>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TokenResponse { + pub access_token: String, + pub token_type: String, + pub expires_in: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub refresh_token: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub scope: Option<String>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ErrorResponse { + pub error: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub error_description: Option<String>, +} + +#[derive(Debug, Clone)] +pub struct AuthCode { + pub client_id: String, + pub redirect_uri: String, + pub scope: Option<String>, + pub expires_at: u64, + pub user_id: String, +}
\ No newline at end of file |
