Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion snowflake-api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Since it does a lot of I/O the library is async-only, and currently has hard dep

- [x] Single statements [example](./examples/run_sql.rs)
- [ ] Multiple statements
- [ ] Async requests (is it needed if whole library is async?)
- [ ] Async requests (to allow for long-running queries and multi-statement)
- [x] Query results in [Arrow](https://arrow.apache.org/)
- [x] Chunked query results
- [x] Password, certificate, env auth
Expand Down
32 changes: 27 additions & 5 deletions snowflake-api/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![doc(
issue_tracker_base_url = "https://github.com/mycelial/snowflake-rs/issues",
test(no_crate_inject)
issue_tracker_base_url = "https://github.com/mycelial/snowflake-rs/issues",
test(no_crate_inject)
)]
#![doc = include_str!("../README.md")]
#![warn(clippy::all, clippy::pedantic)]
Expand All @@ -13,6 +13,7 @@ clippy::future_not_send, // This one seems like something we should eventually f
clippy::missing_panics_doc
)]

use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::io;
use std::path::Path;
Expand All @@ -31,10 +32,10 @@ use regex::Regex;
use reqwest_middleware::ClientWithMiddleware;
use thiserror::Error;

use crate::connection::{Connection, ConnectionError};
use responses::ExecResponse;
use session::{AuthError, Session};

use crate::connection::{Connection, ConnectionError};
use crate::connection::QueryType;
use crate::requests::ExecRequest;
use crate::responses::{
Expand Down Expand Up @@ -395,8 +396,9 @@ impl SnowflakeApi {
log::debug!("Got PUT response: {:?}", resp);

match resp {
ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::PutGet(pg) => self.put(pg).await,
// put-get by design is async, and isn't a query response
ExecResponse::MultiStatementQuery(_) | ExecResponse::AsyncQuery(_) | ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
e.data.error_code,
e.message.unwrap_or_default(),
Expand Down Expand Up @@ -479,6 +481,8 @@ impl SnowflakeApi {
// processable response
ExecResponse::Query(qr) => Ok(qr),
ExecResponse::PutGet(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::AsyncQuery(_) => Err(SnowflakeApiError::Unimplemented("Async queries, ie the ones returning a handle to query id".to_owned())),
ExecResponse::MultiStatementQuery(_) => Err(SnowflakeApiError::Unimplemented("Multi-statement queries are not implemented as they require polling on the user-side".to_owned())),
ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
e.data.error_code,
e.message.unwrap_or_default(),
Expand All @@ -505,7 +509,7 @@ impl SnowflakeApi {
self.connection
.get_chunk(&chunk.url, &resp.data.chunk_headers)
}))
.await?;
.await?;

// fixme: should base64 chunk go first?
// fixme: if response is chunked is it both base64 + chunks or just chunks?
Expand All @@ -529,12 +533,17 @@ impl SnowflakeApi {
log::debug!("Executing: {}", sql_text);

let parts = self.session.get_token().await?;
// todo: move clientStartTime, requestId, request_guid from request parameters to request body into this map
let mut parameters = HashMap::new();
parameters.insert("MULTI_STATEMENT_COUNT".to_owned(), Self::count_statements(sql_text).to_string());

let body = ExecRequest {
sql_text: sql_text.to_string(),
async_exec: false,
sequence_id: parts.sequence_id,
is_internal: false,
describe_only: false,
parameters,
};

let resp = self
Expand All @@ -550,4 +559,17 @@ impl SnowflakeApi {

Ok(resp)
}

/// Estimates how many SQL statements `sql_text` contains.
///
/// The result is sent to the server as the `MULTI_STATEMENT_COUNT` request
/// parameter, so it has to cover every statement in the text.
///
/// Counting rules:
/// * every `;` terminates one statement;
/// * non-blank text after the last `;` is one more (unterminated) statement,
///   e.g. `"select 1; select 2"` is two statements;
/// * the result is never below `1`, so a single statement without a
///   terminator (or an empty string) is still reported as one statement.
fn count_statements(sql_text: &str) -> usize {
    // fixme: this is a purely textual scan; a `;` inside a string literal or a
    // comment is still treated as a statement terminator
    let terminated = sql_text.matches(';').count();

    // `rsplit` always yields at least one item: the text following the last
    // `;`, or the whole input when there is no `;` at all
    let has_unterminated_tail = sql_text
        .rsplit(';')
        .next()
        .map_or(false, |tail| !tail.trim().is_empty());

    // non-terminated single query is still a single query
    (terminated + usize::from(has_unterminated_tail)).max(1)
}
}
10 changes: 7 additions & 3 deletions snowflake-api/src/requests.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
use std::collections::HashMap;

use serde::Serialize;

#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct ExecRequest {
pub sql_text: String,
pub async_exec: bool,
pub sequence_id: u64,
pub is_internal: bool,
pub async_exec: bool, // fixme: doesn't exist in .NET
pub sequence_id: u64, // fixme: doesn't exist in .NET
pub is_internal: bool, // fixme: doesn't exist in .NET
pub describe_only: bool, // fixme: optional in GO, required in .NET
pub parameters: HashMap<String, String>,
}

#[derive(Serialize, Debug)]
Expand Down
54 changes: 40 additions & 14 deletions snowflake-api/src/responses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ use serde::Deserialize;
#[derive(Deserialize, Debug)]
#[serde(untagged)]
pub enum ExecResponse {
Query(QueryExecResponse),
PutGet(PutGetExecResponse),
Error(ExecErrorResponse),
AsyncQuery(AsyncQueryResponse),
MultiStatementQuery(MultiStatementQueryResponse),
Query(QueryExecResponse),
// before-last since has intersecting fields
Error(ExecErrorResponse), // last since essentially catch-all
}

// todo: add close session response, which should be just empty?
Expand All @@ -32,6 +35,8 @@ pub struct BaseRestResponse<D> {
pub data: D,
}

pub type MultiStatementQueryResponse = BaseRestResponse<MultiStatementQueryResponseData>;
pub type AsyncQueryResponse = BaseRestResponse<AsyncQueryResponseData>;
pub type PutGetExecResponse = BaseRestResponse<PutGetResponseData>;
pub type QueryExecResponse = BaseRestResponse<QueryExecResponseData>;
pub type ExecErrorResponse = BaseRestResponse<ExecErrorResponseData>;
Expand Down Expand Up @@ -124,15 +129,21 @@ pub struct QueryExecResponseData {
// is base64-encoded Arrow IPC payload
pub rowset_base64: Option<String>,
pub total: i64,
pub returned: i64, // unused in .NET
pub query_id: String, // unused in .NET
pub returned: i64,
// unused in .NET
pub query_id: String,
// unused in .NET
pub database_provider: Option<String>,
pub final_database_name: Option<String>, // unused in .NET
pub final_database_name: Option<String>,
// unused in .NET
pub final_schema_name: Option<String>,
pub final_warehouse_name: Option<String>, // unused in .NET
pub final_role_name: String, // unused in .NET
pub final_warehouse_name: Option<String>,
// unused in .NET
pub final_role_name: String,
// unused in .NET
// only present on SELECT queries
pub number_of_binds: Option<i32>, // unused in .NET
pub number_of_binds: Option<i32>,
// unused in .NET
// todo: deserialize into enum
pub statement_type_id: i64,
pub version: i64,
Expand All @@ -143,12 +154,6 @@ pub struct QueryExecResponseData {
pub qrmk: Option<String>,
#[serde(default)] // chunks are present
pub chunk_headers: HashMap<String, String>,
// when async query is run (ping pong request?)
pub get_result_url: Option<String>,
// multi-statement response, comma-separated
pub result_ids: Option<String>,
// `progressDesc`, and `queryAbortAfterSecs` are not used but exist in .NET
// `sendResultTime`, `queryResultFormat`, `queryContext` also exist
}

#[derive(Deserialize, Debug)]
Expand Down Expand Up @@ -304,3 +309,24 @@ pub struct PutGetEncryptionMaterial {
pub query_id: String,
pub smk_id: i64,
}

/// `data` payload of an [`AsyncQueryResponse`].
///
/// Field names are mapped from camelCase JSON keys (`queryId`, `getResultUrl`,
/// `queryAbortsAfterSecs`, `progressDesc`). Every non-`Option` field is
/// required: a payload missing one of them does not deserialize as this type.
///
/// NOTE(review): presumably returned when the server hands back a handle to a
/// still-running query instead of its result set — confirm against the API.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct AsyncQueryResponseData {
// id of the query this response refers to
pub query_id: String,
// NOTE(review): presumably the URL to request the result from once the query
// has finished — confirm
pub get_result_url: String,
// NOTE(review): unit is seconds judging by the name; exact meaning of the
// timeout is not visible here — confirm
pub query_aborts_after_secs: i64,
// optional progress description; may be absent from the response
pub progress_desc: Option<String>,
}

/// `data` payload of a [`MultiStatementQueryResponse`].
///
/// Field names are mapped from camelCase JSON keys (`queryId`, `resultIds`,
/// `resultTypes`); all three are required for deserialization to succeed.
// fixme: this is not correct, but useful
// since the response will include more fields from [`QueryExecResponseData`]
// (only the fields declared here are captured by this struct)
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct MultiStatementQueryResponseData {
// id of the query this response refers to
pub query_id: String,
// comma-separated
// NOTE(review): presumably one result id per executed statement — confirm
pub result_ids: String,
// comma-separated
// NOTE(review): presumably one entry per id in `result_ids`, in the same
// order — confirm
pub result_types: String,
}