Skip to content
This repository was archived by the owner on Oct 18, 2023. It is now read-only.

Commit d29081a

Browse files
committed
simplify stream proxy implementation
1 parent b7f1ef7 commit d29081a

File tree

4 files changed

+89
-142
lines changed

4 files changed

+89
-142
lines changed

Cargo.lock

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sqld/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ rustls-pemfile = "1.0.3"
7171
rustls = "0.21.7"
7272
async-stream = "0.3.5"
7373
libsql = { git = "https://github.com/tursodatabase/libsql.git", rev = "bea8863", optional = true }
74+
futures-option = "0.2.0"
7475

7576
[dev-dependencies]
7677
proptest = "1.0.0"

sqld/src/rpc/proxy.rs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
use std::collections::HashMap;
2+
use std::pin::Pin;
23
use std::str::FromStr;
34
use std::sync::Arc;
45

56
use async_lock::{RwLock, RwLockUpgradableReadGuard};
7+
use futures_core::Stream;
68
use rusqlite::types::ValueRef;
79
use uuid::Uuid;
810

@@ -15,14 +17,14 @@ use crate::query_result_builder::{
1517
Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError,
1618
};
1719
use crate::replication::FrameNo;
20+
use crate::rpc::streaming_exec::make_proxy_stream;
1821

1922
use self::rpc::proxy_server::Proxy;
2023
use self::rpc::query_result::RowResult;
2124
use self::rpc::{
2225
describe_result, Ack, DescribeRequest, DescribeResult, Description, DisconnectMessage, ExecReq,
23-
ExecuteResults, QueryResult, ResultRows, Row,
26+
ExecuteResults, QueryResult, ResultRows, Row, ExecResp,
2427
};
25-
use super::streaming_exec::StreamRequestHandler;
2628
use super::NAMESPACE_DOESNT_EXIST;
2729

2830
pub mod rpc {
@@ -467,20 +469,18 @@ pub async fn garbage_collect(clients: &mut HashMap<Uuid, Arc<PrimaryConnection>>
467469

468470
#[tonic::async_trait]
469471
impl Proxy for ProxyService {
470-
type StreamExecStream = StreamRequestHandler<tonic::Streaming<ExecReq>>;
472+
type StreamExecStream = Pin<Box<dyn Stream<Item = Result<ExecResp, tonic::Status>> + Send>>;
471473

472474
async fn stream_exec(
473475
&self,
474476
req: tonic::Request<tonic::Streaming<ExecReq>>,
475477
) -> Result<tonic::Response<Self::StreamExecStream>, tonic::Status> {
476-
dbg!();
477-
let authenticated = if let Some(auth) = &self.auth {
478+
let auth= if let Some(auth) = &self.auth {
478479
auth.authenticate_grpc(&req, self.disable_namespaces)?
479480
} else {
480481
Authenticated::from_proxy_grpc_request(&req, self.disable_namespaces)?
481482
};
482483

483-
dbg!();
484484
let namespace = super::extract_namespace(self.disable_namespaces, &req)?;
485485
let (connection_maker, _new_frame_notifier) = self
486486
.namespaces
@@ -498,13 +498,11 @@ impl Proxy for ProxyService {
498498
}
499499
})?;
500500

501-
dbg!();
502-
let connection = connection_maker.create().await.unwrap();
501+
let conn = connection_maker.create().await.unwrap();
503502

504-
dbg!();
505-
let handler = StreamRequestHandler::new(req.into_inner(), connection, authenticated);
503+
let stream = make_proxy_stream(conn, auth, req.into_inner());
506504

507-
Ok(tonic::Response::new(handler))
505+
Ok(tonic::Response::new(Box::pin(stream)))
508506
}
509507

510508
async fn execute(

sqld/src/rpc/streaming_exec.rs

Lines changed: 69 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
use std::pin::Pin;
21
use std::sync::Arc;
3-
use std::task::{ready, Context, Poll};
42

53
use futures_core::Stream;
4+
use futures_option::OptionExt;
65
use prost::Message;
76
use rusqlite::types::ValueRef;
7+
use tokio::pin;
88
use tokio::sync::mpsc;
9+
use tokio_stream::StreamExt;
910
use tonic::{Code, Status};
1011

1112
use crate::auth::Authenticated;
@@ -17,32 +18,76 @@ use crate::query_result_builder::{
1718
};
1819
use crate::replication::FrameNo;
1920
use crate::rpc::proxy::rpc::exec_req::Request;
20-
use crate::rpc::proxy::rpc::exec_resp;
21+
use crate::rpc::proxy::rpc::exec_resp::{self, Response};
2122

2223
use super::proxy::rpc::resp_step::Step;
2324
use super::proxy::rpc::{self, ExecReq, ExecResp, ProgramResp, RespStep, RowValue};
2425

25-
pin_project_lite::pin_project! {
26-
pub struct StreamRequestHandler<S> {
27-
#[pin]
28-
request_stream: S,
29-
connection: Arc<PrimaryConnection>,
30-
state: State,
31-
authenticated: Authenticated,
32-
}
33-
}
34-
35-
impl<S> StreamRequestHandler<S> {
36-
pub fn new(
37-
request_stream: S,
38-
connection: PrimaryConnection,
39-
authenticated: Authenticated,
40-
) -> Self {
41-
Self {
42-
request_stream,
43-
connection: connection.into(),
44-
state: State::Idle,
45-
authenticated,
26+
pub fn make_proxy_stream<S>(conn: PrimaryConnection, auth: Authenticated, request_stream: S) -> impl Stream<Item = Result<ExecResp, Status>>
27+
where
28+
S: Stream<Item = Result<ExecReq, Status>>,
29+
{
30+
async_stream::stream! {
31+
let mut current_request_fut = None;
32+
let (snd, mut recv) = mpsc::channel(1);
33+
let conn = Arc::new(conn);
34+
pin!(request_stream);
35+
36+
loop {
37+
tokio::select! {
38+
biased;
39+
Some(maybe_req) = request_stream.next() => {
40+
match maybe_req {
41+
Err(e) => {
42+
tracing::error!("stream error: {e}");
43+
break
44+
}
45+
Ok(req) => {
46+
let request_id = req.request_id;
47+
match req.request {
48+
Some(Request::Execute(pgm)) => {
49+
let Ok(pgm) =
50+
crate::connection::program::Program::try_from(pgm.pgm.unwrap()) else {
51+
yield Err(Status::new(Code::InvalidArgument, "invalid program"));
52+
break
53+
};
54+
let conn = conn.clone();
55+
let auth = auth.clone();
56+
let sender = snd.clone();
57+
58+
let fut = async move {
59+
let builder = StreamResponseBuilder {
60+
request_id,
61+
sender,
62+
current: None,
63+
current_size: 0,
64+
};
65+
66+
let ret = conn.execute_program(pgm, auth, builder, None).await;
67+
(ret, request_id)
68+
};
69+
70+
current_request_fut.replace(Box::pin(fut));
71+
}
72+
Some(Request::Describe(_)) => todo!(),
73+
None => {
74+
yield Err(Status::new(Code::InvalidArgument, "invalid request"));
75+
break
76+
}
77+
}
78+
}
79+
}
80+
},
81+
Some(res) = recv.recv() => {
82+
yield Ok(res);
83+
},
84+
(ret, request_id) = current_request_fut.current(), if current_request_fut.is_some() => {
85+
if let Err(e) = ret {
86+
yield Ok(ExecResp { request_id, response: Some(Response::Error(e.into())) })
87+
}
88+
},
89+
else => break,
90+
}
4691
}
4792
}
4893
}
@@ -199,110 +244,3 @@ impl From<ValueRef<'_>> for RowValue {
199244
RowValue { value }
200245
}
201246
}
202-
203-
enum State {
204-
Execute(Pin<Box<dyn Stream<Item = ExecResp> + Send>>),
205-
Idle,
206-
Fused,
207-
}
208-
209-
impl<S> Stream for StreamRequestHandler<S>
210-
where
211-
S: Stream<Item = Result<ExecReq, Status>>,
212-
{
213-
type Item = Result<ExecResp, Status>;
214-
215-
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
216-
let this = self.project();
217-
218-
// we always poll from the request stream. If a new request arrive, we interupt the current
219-
// one, and move to the next.
220-
if let Poll::Ready(maybe_req) = this.request_stream.poll_next(cx) {
221-
match maybe_req {
222-
Some(Err(e)) => {
223-
*this.state = State::Fused;
224-
return Poll::Ready(Some(Err(e)))
225-
}
226-
Some(Ok(req)) => {
227-
let request_id = req.request_id;
228-
match req.request {
229-
Some(Request::Execute(pgm)) => {
230-
let Ok(pgm) =
231-
crate::connection::program::Program::try_from(pgm.pgm.unwrap()) else {
232-
*this.state = State::Fused;
233-
return Poll::Ready(Some(Err(Status::new(Code::InvalidArgument, "invalid program"))));
234-
};
235-
let conn = this.connection.clone();
236-
let authenticated = this.authenticated.clone();
237-
238-
let s = async_stream::stream! {
239-
let (sender, mut receiver) = mpsc::channel(1);
240-
let builder = StreamResponseBuilder {
241-
request_id,
242-
sender,
243-
current: None,
244-
current_size: 0,
245-
};
246-
let mut fut = conn.execute_program(pgm, authenticated, builder, None);
247-
loop {
248-
tokio::select! {
249-
res = &mut fut => {
250-
// drain the receiver
251-
while let Ok(msg) = receiver.try_recv() {
252-
yield msg;
253-
}
254-
255-
if let Err(e) = res {
256-
yield ExecResp {
257-
request_id,
258-
response: Some(exec_resp::Response::Error(e.into()))
259-
}
260-
}
261-
break
262-
}
263-
msg = receiver.recv() => {
264-
if let Some(msg) = msg {
265-
yield msg;
266-
}
267-
}
268-
}
269-
}
270-
};
271-
*this.state = State::Execute(Box::pin(s));
272-
}
273-
Some(Request::Describe(_)) => todo!(),
274-
None => {
275-
*this.state = State::Fused;
276-
return Poll::Ready(Some(Err(Status::new(
277-
Code::InvalidArgument,
278-
"invalid ExecReq: missing request",
279-
))));
280-
}
281-
}
282-
}
283-
None => {
284-
*this.state = State::Fused;
285-
return Poll::Ready(None)
286-
}
287-
}
288-
}
289-
290-
match this.state {
291-
State::Idle => Poll::Pending,
292-
State::Fused => Poll::Ready(None),
293-
State::Execute(stream) => {
294-
let resp = ready!(stream.as_mut().poll_next(cx));
295-
match resp {
296-
Some(resp) => Poll::Ready(Some(Ok(resp))),
297-
None => {
298-
// finished processing this query. Wake up immediately to prepare for the
299-
// next
300-
*this.state = State::Idle;
301-
cx.waker().wake_by_ref();
302-
Poll::Pending
303-
}
304-
}
305-
}
306-
}
307-
}
308-
}

0 commit comments

Comments
 (0)