diff --git a/snowflake-api/README.md b/snowflake-api/README.md index b08f911..7d6a22a 100644 --- a/snowflake-api/README.md +++ b/snowflake-api/README.md @@ -44,21 +44,24 @@ snowflake-api = "0.7.0" Check [examples](./examples) for working programs using the library. - ```rust use anyhow::Result; -use snowflake_api::{QueryResult, SnowflakeApi}; +use snowflake_api::{QueryResult, AuthArgs, PasswordArgs, AuthType, SnowflakeApi, SnowflakeApiBuilder}; async fn run_query(sql: &str) -> Result { - let mut api = SnowflakeApi::with_password_auth( + + let auth = AuthArgs::new( "ACCOUNT_IDENTIFIER", Some("WAREHOUSE"), Some("DATABASE"), Some("SCHEMA"), "USERNAME", Some("ROLE"), - "PASSWORD", - )?; + AuthType::Password(PasswordArgs { password: "password".to_string() }) + ); + + let mut api: SnowflakeApi = SnowflakeApiBuilder::new(auth) + .build()?; let res = api.exec(sql).await?; Ok(res) @@ -68,7 +71,7 @@ async fn run_query(sql: &str) -> Result { Or using environment variables: ```rust - use anyhow::Result; +use anyhow::Result; use snowflake_api::{QueryResult, SnowflakeApi}; async fn run_query(sql: &str) -> Result { diff --git a/snowflake-api/examples/run_sql.rs b/snowflake-api/examples/run_sql.rs index 18ec8a9..e9edf38 100644 --- a/snowflake-api/examples/run_sql.rs +++ b/snowflake-api/examples/run_sql.rs @@ -5,7 +5,9 @@ use arrow::util::pretty::pretty_format_batches; use clap::Parser; use std::fs; -use snowflake_api::{QueryResult, SnowflakeApi}; +use snowflake_api::{ + AuthArgs, CertificateArgs, PasswordArgs, QueryResult, SnowflakeApi, SnowflakeApiBuilder, +}; #[derive(clap::ValueEnum, Clone, Debug)] enum Output { @@ -67,25 +69,31 @@ async fn main() -> Result<()> { let mut api = match (&args.private_key, &args.password) { (Some(pkey), None) => { let pem = fs::read_to_string(pkey)?; - SnowflakeApi::with_certificate_auth( - &args.account_identifier, - args.warehouse.as_deref(), - args.database.as_deref(), - args.schema.as_deref(), - &args.username, - args.role.as_deref(), - &pem, - )? + SnowflakeApiBuilder::new(AuthArgs { + account_identifier: args.account_identifier, + warehouse: args.warehouse, + database: args.database, + schema: args.schema, + username: args.username, + role: args.role, + auth_type: snowflake_api::AuthType::Certificate(CertificateArgs { + private_key_pem: pem, + }), + }) + .build()? } - (None, Some(pwd)) => SnowflakeApi::with_password_auth( - &args.account_identifier, - args.warehouse.as_deref(), - args.database.as_deref(), - args.schema.as_deref(), - &args.username, - args.role.as_deref(), - pwd, - )?, + (None, Some(pwd)) => SnowflakeApiBuilder::new(AuthArgs { + account_identifier: args.account_identifier, + warehouse: args.warehouse, + database: args.database, + schema: args.schema, + username: args.username, + role: args.role, + auth_type: snowflake_api::AuthType::Password(PasswordArgs { + password: pwd.to_string(), + }), + }) + .build()?, _ => { panic!("Either private key path or password must be set") } diff --git a/snowflake-api/src/lib.rs b/snowflake-api/src/lib.rs index 1fa7b36..db7947d 100644 --- a/snowflake-api/src/lib.rs +++ b/snowflake-api/src/lib.rs @@ -28,7 +28,7 @@ use reqwest_middleware::ClientWithMiddleware; use thiserror::Error; use responses::ExecResponse; -use session::{AuthError, Session}; +use session::{AuthError, Session, SessionBuilder}; use crate::connection::QueryType; use crate::connection::{Connection, ConnectionError}; @@ -193,11 +193,33 @@ pub struct AuthArgs { } impl AuthArgs { + pub fn new( + account_identifier: &str, + warehouse: Option<&str>, + database: Option<&str>, + schema: Option<&str>, + username: &str, + role: Option<&str>, + auth_type: AuthType, + ) -> Self { + Self { + account_identifier: account_identifier.to_string(), + warehouse: warehouse.map(str::to_string), + database: database.map(str::to_string), + schema: schema.map(str::to_string), + username: username.to_string(), + role: role.map(str::to_string), + auth_type, + } + } + pub fn from_env() -> Result { let auth_type = if let Ok(password) = std::env::var("SNOWFLAKE_PASSWORD") { Ok(AuthType::Password(PasswordArgs { password })) } else if let Ok(private_key_pem) = std::env::var("SNOWFLAKE_PRIVATE_KEY") { Ok(AuthType::Certificate(CertificateArgs { private_key_pem })) + } else if let Ok(token) = std::env::var("SNOWFLAKE_OAUTH_TOKEN") { + Ok(AuthType::OAuth(OAuthArgs { token })) } else { Err(MissingEnvArgument( "SNOWFLAKE_PASSWORD or SNOWFLAKE_PRIVATE_KEY".to_owned(), @@ -221,6 +243,7 @@ impl AuthArgs { pub enum AuthType { Password(PasswordArgs), Certificate(CertificateArgs), + OAuth(OAuthArgs), } pub struct PasswordArgs { @@ -231,6 +254,10 @@ pub struct CertificateArgs { pub private_key_pem: String, } +pub struct OAuthArgs { + pub token: String, +} + #[must_use] pub struct SnowflakeApiBuilder { pub auth: AuthArgs, @@ -253,27 +280,20 @@ impl SnowflakeApiBuilder { None => Arc::new(Connection::new()?), }; + let session = SessionBuilder::new(&self.auth.account_identifier, &self.auth.username) + .warehouse(self.auth.warehouse.as_deref()) + .database(self.auth.database.as_deref()) + .schema(self.auth.schema.as_deref()) + .role(self.auth.role.as_deref()); + let session = match self.auth.auth_type { - AuthType::Password(args) => Session::password_auth( - Arc::clone(&connection), - &self.auth.account_identifier, - self.auth.warehouse.as_deref(), - self.auth.database.as_deref(), - self.auth.schema.as_deref(), - &self.auth.username, - self.auth.role.as_deref(), - &args.password, - ), - AuthType::Certificate(args) => Session::cert_auth( - Arc::clone(&connection), - &self.auth.account_identifier, - self.auth.warehouse.as_deref(), - self.auth.database.as_deref(), - self.auth.schema.as_deref(), - &self.auth.username, - self.auth.role.as_deref(), - &args.private_key_pem, - ), + AuthType::Password(args) => { + session.build_password(Arc::clone(&connection), &args.password) + } + AuthType::Certificate(args) => { + session.build_cert(Arc::clone(&connection), &args.private_key_pem) + } + AuthType::OAuth(args) => session.build_oauth(Arc::clone(&connection), &args.token), }; let account_identifier = self.auth.account_identifier.to_uppercase(); @@ -302,67 +322,6 @@ impl SnowflakeApi { account_identifier, } } - /// Initialize object with password auth. Authentication happens on the first request. - pub fn with_password_auth( - account_identifier: &str, - warehouse: Option<&str>, - database: Option<&str>, - schema: Option<&str>, - username: &str, - role: Option<&str>, - password: &str, - ) -> Result { - let connection = Arc::new(Connection::new()?); - - let session = Session::password_auth( - Arc::clone(&connection), - account_identifier, - warehouse, - database, - schema, - username, - role, - password, - ); - - let account_identifier = account_identifier.to_uppercase(); - Ok(Self::new( - Arc::clone(&connection), - session, - account_identifier, - )) - } - - /// Initialize object with private certificate auth. Authentication happens on the first request. - pub fn with_certificate_auth( - account_identifier: &str, - warehouse: Option<&str>, - database: Option<&str>, - schema: Option<&str>, - username: &str, - role: Option<&str>, - private_key_pem: &str, - ) -> Result { - let connection = Arc::new(Connection::new()?); - - let session = Session::cert_auth( - Arc::clone(&connection), - account_identifier, - warehouse, - database, - schema, - username, - role, - private_key_pem, - ); - - let account_identifier = account_identifier.to_uppercase(); - Ok(Self::new( - Arc::clone(&connection), - session, - account_identifier, - )) - } pub fn from_env() -> Result { SnowflakeApiBuilder::new(AuthArgs::from_env()?).build() diff --git a/snowflake-api/src/requests.rs b/snowflake-api/src/requests.rs index 77b0434..ace75d3 100644 --- a/snowflake-api/src/requests.rs +++ b/snowflake-api/src/requests.rs @@ -15,6 +15,7 @@ pub struct LoginRequest { } pub type PasswordLoginRequest = LoginRequest; +pub type OAuthLoginRequest = LoginRequest; #[cfg(feature = "cert-auth")] pub type CertLoginRequest = LoginRequest; @@ -62,6 +63,15 @@ pub struct CertRequestData { pub token: String, } +#[derive(Serialize, Debug)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub struct OAuthRequestData { + #[serde(flatten)] + pub login_request_common: LoginRequestCommon, + pub authenticator: String, + pub token: String, +} + #[derive(Serialize, Debug)] #[serde(rename_all = "camelCase")] pub struct RenewSessionRequest { diff --git a/snowflake-api/src/session.rs b/snowflake-api/src/session.rs index 90acaaf..2549069 100644 --- a/snowflake-api/src/session.rs +++ b/snowflake-api/src/session.rs @@ -11,8 +11,8 @@ use crate::connection::{Connection, QueryType}; #[cfg(feature = "cert-auth")] use crate::requests::{CertLoginRequest, CertRequestData}; use crate::requests::{ - ClientEnvironment, LoginRequest, LoginRequestCommon, PasswordLoginRequest, PasswordRequestData, - RenewSessionRequest, SessionParameters, + ClientEnvironment, LoginRequest, LoginRequestCommon, OAuthLoginRequest, OAuthRequestData, + PasswordLoginRequest, PasswordRequestData, RenewSessionRequest, SessionParameters, }; use crate::responses::AuthResponse; @@ -50,6 +50,9 @@ pub enum AuthError { #[error("Enable the cert-auth feature to use certificate authentication")] CertAuthNotEnabled, + + #[error("The authentication type has not been set yet on the connection!")] + AuthTypeUnset, } #[derive(Debug)] @@ -103,14 +106,15 @@ impl AuthToken { } enum AuthType { - Certificate, - Password, + Certificate(String), + Password(String), + OAuth(String), } /// Requests, caches, and renews authentication tokens. /// Tokens are given as response to creating new session in Snowflake. Session persists /// the configuration state and temporary objects (tables, procedures, etc). -// todo: split warehouse-database-schema and username-role-key into its own structs + // todo: close session after object is dropped pub struct Session { connection: Arc, @@ -119,92 +123,110 @@ pub struct Session { auth_type: AuthType, account_identifier: String, - warehouse: Option, - database: Option, - schema: Option, - username: String, role: Option, - // This is not used with the certificate auth crate - #[allow(dead_code)] - private_key_pem: Option, - password: Option, -} -// todo: make builder -impl Session { - /// Authenticate using private certificate and JWT - // fixme: add builder or introduce structs - #[allow(clippy::too_many_arguments)] - pub fn cert_auth( - connection: Arc, - account_identifier: &str, - warehouse: Option<&str>, - database: Option<&str>, - schema: Option<&str>, - username: &str, - role: Option<&str>, - private_key_pem: &str, - ) -> Self { - // uppercase everything as this is the convention - let account_identifier = account_identifier.to_uppercase(); + object_details: SessionObjectDetails, +} - let database = database.map(str::to_uppercase); - let schema = schema.map(str::to_uppercase); +#[must_use] +pub struct SessionBuilder { + account_identifier: String, + object_details: SessionObjectDetails, + username: String, + role: Option, +} +impl SessionBuilder { + pub fn new(account_identifier: &str, username: &str) -> Self { let username = username.to_uppercase(); - let role = role.map(str::to_uppercase); - let private_key_pem = Some(private_key_pem.to_string()); Self { - connection, - auth_tokens: Mutex::new(None), - auth_type: AuthType::Certificate, - private_key_pem, - account_identifier, - warehouse: warehouse.map(str::to_uppercase), - database, + account_identifier: account_identifier.to_string(), + object_details: SessionObjectDetails::default(), username, - role, - schema, - password: None, + role: None, } } - /// Authenticate using password - // fixme: add builder or introduce structs - #[allow(clippy::too_many_arguments)] - pub fn password_auth( + pub fn warehouse(mut self, warehouse: Option<&str>) -> Self { + self.object_details.warehouse = warehouse.map(str::to_string); + self + } + + pub fn database(mut self, database: Option<&str>) -> Self { + self.object_details.database = database.map(str::to_string); + self + } + + pub fn schema(mut self, schema: Option<&str>) -> Self { + self.object_details.schema = schema.map(str::to_string); + self + } + + pub fn role(mut self, role: Option<&str>) -> Self { + self.role = role.map(str::to_string); + self + } + + pub fn build_oauth(&self, connection: Arc, oauth_access_token: &str) -> Session { + Session::new( + connection, + &self.account_identifier, + AuthType::OAuth(oauth_access_token.to_string()), + self.object_details.clone(), + &self.username, + self.role.as_deref(), + ) + } + + pub fn build_password(&self, connection: Arc, password: &str) -> Session { + Session::new( + connection, + &self.account_identifier, + AuthType::Password(password.to_string()), + self.object_details.clone(), + &self.username, + self.role.as_deref(), + ) + } + + pub fn build_cert(&self, connection: Arc, private_key_pem: &str) -> Session { + Session::new( + connection, + &self.account_identifier, + AuthType::Certificate(private_key_pem.to_string()), + self.object_details.clone(), + &self.username, + self.role.as_deref(), + ) + } +} + +#[derive(Debug, Default, Clone)] +struct SessionObjectDetails { + warehouse: Option, + database: Option, + schema: Option, +} + +impl Session { + fn new( connection: Arc, account_identifier: &str, - warehouse: Option<&str>, - database: Option<&str>, - schema: Option<&str>, + auth_type: AuthType, + object_details: SessionObjectDetails, username: &str, role: Option<&str>, - password: &str, ) -> Self { - let account_identifier = account_identifier.to_uppercase(); - - let database = database.map(str::to_uppercase); - let schema = schema.map(str::to_uppercase); - - let username = username.to_uppercase(); - let password = Some(password.to_string()); - let role = role.map(str::to_uppercase); - Self { connection, auth_tokens: Mutex::new(None), - auth_type: AuthType::Password, - account_identifier, - warehouse: warehouse.map(str::to_uppercase), - database, - username, - role, - password, - schema, - private_key_pem: None, + auth_type, + account_identifier: account_identifier.to_string(), + username: username.to_string(), + role: role.map(str::to_string), + object_details, } } @@ -217,18 +239,22 @@ impl Session { .is_some_and(|at| at.master_token.is_expired()) { // Create new session if tokens are absent or can not be exchange - let tokens = match self.auth_type { - AuthType::Certificate => { + let tokens = match &self.auth_type { + AuthType::Certificate(pem) => { log::info!("Starting session with certificate authentication"); if cfg!(feature = "cert-auth") { - self.create(self.cert_request_body()?).await + self.create(self.cert_request_body(pem)?).await } else { Err(AuthError::MissingCertificate)? } } - AuthType::Password => { + AuthType::Password(pwd) => { log::info!("Starting session with password authentication"); - self.create(self.passwd_request_body()?).await + self.create(self.passwd_request_body(pwd)).await + } + AuthType::OAuth(token) => { + log::info!("Starting session with oauth authentication"); + self.create(self.oauth_request_body(token)).await } }?; *auth_tokens = Some(tokens); @@ -277,12 +303,9 @@ impl Session { } #[cfg(feature = "cert-auth")] - fn cert_request_body(&self) -> Result { + fn cert_request_body(&self, private_key_pem: &str) -> Result { let full_identifier = format!("{}.{}", &self.account_identifier, &self.username); - let private_key_pem = self - .private_key_pem - .as_ref() - .ok_or(AuthError::MissingCertificate)?; + let jwt_token = generate_jwt_token(private_key_pem, &full_identifier)?; Ok(CertLoginRequest { @@ -294,15 +317,23 @@ impl Session { }) } - fn passwd_request_body(&self) -> Result { - let password = self.password.as_ref().ok_or(AuthError::MissingPassword)?; - - Ok(PasswordLoginRequest { + fn passwd_request_body(&self, password: &str) -> PasswordLoginRequest { + PasswordLoginRequest { data: PasswordRequestData { login_request_common: self.login_request_common(), password: password.to_string(), }, - }) + } + } + + fn oauth_request_body(&self, oauth_access_token: &str) -> OAuthLoginRequest { + OAuthLoginRequest { + data: OAuthRequestData { + login_request_common: self.login_request_common(), + authenticator: "OAUTH".to_string(), + token: oauth_access_token.to_string(), + }, + } } /// Start new session, all the Snowflake temporary objects will be scoped towards it, @@ -312,15 +343,15 @@ impl Session { body: LoginRequest, ) -> Result { let mut get_params = Vec::new(); - if let Some(warehouse) = &self.warehouse { + if let Some(warehouse) = &self.object_details.warehouse { get_params.push(("warehouse", warehouse.as_str())); } - if let Some(database) = &self.database { + if let Some(database) = &self.object_details.database { get_params.push(("databaseName", database.as_str())); } - if let Some(schema) = &self.schema { + if let Some(schema) = &self.object_details.schema { get_params.push(("schemaName", schema.as_str())); }