1- use std:: pin:: Pin ;
21use std:: sync:: Arc ;
3- use std:: task:: { ready, Context , Poll } ;
42
53use futures_core:: Stream ;
4+ use futures_option:: OptionExt ;
65use prost:: Message ;
76use rusqlite:: types:: ValueRef ;
7+ use tokio:: pin;
88use tokio:: sync:: mpsc;
9+ use tokio_stream:: StreamExt ;
910use tonic:: { Code , Status } ;
1011
1112use crate :: auth:: Authenticated ;
@@ -17,32 +18,76 @@ use crate::query_result_builder::{
1718} ;
1819use crate :: replication:: FrameNo ;
1920use crate :: rpc:: proxy:: rpc:: exec_req:: Request ;
20- use crate :: rpc:: proxy:: rpc:: exec_resp;
21+ use crate :: rpc:: proxy:: rpc:: exec_resp:: { self , Response } ;
2122
2223use super :: proxy:: rpc:: resp_step:: Step ;
2324use super :: proxy:: rpc:: { self , ExecReq , ExecResp , ProgramResp , RespStep , RowValue } ;
2425
25- pin_project_lite:: pin_project! {
26- pub struct StreamRequestHandler <S > {
27- #[ pin]
28- request_stream: S ,
29- connection: Arc <PrimaryConnection >,
30- state: State ,
31- authenticated: Authenticated ,
32- }
33- }
34-
35- impl < S > StreamRequestHandler < S > {
36- pub fn new (
37- request_stream : S ,
38- connection : PrimaryConnection ,
39- authenticated : Authenticated ,
40- ) -> Self {
41- Self {
42- request_stream,
43- connection : connection. into ( ) ,
44- state : State :: Idle ,
45- authenticated,
26+ pub fn make_proxy_stream < S > ( conn : PrimaryConnection , auth : Authenticated , request_stream : S ) -> impl Stream < Item = Result < ExecResp , Status > >
27+ where
28+ S : Stream < Item = Result < ExecReq , Status > > ,
29+ {
30+ async_stream:: stream! {
31+ let mut current_request_fut = None ;
32+ let ( snd, mut recv) = mpsc:: channel( 1 ) ;
33+ let conn = Arc :: new( conn) ;
34+ pin!( request_stream) ;
35+
36+ loop {
37+ tokio:: select! {
38+ biased;
39+ Some ( maybe_req) = request_stream. next( ) => {
40+ match maybe_req {
41+ Err ( e) => {
42+ tracing:: error!( "stream error: {e}" ) ;
43+ break
44+ }
45+ Ok ( req) => {
46+ let request_id = req. request_id;
47+ match req. request {
48+ Some ( Request :: Execute ( pgm) ) => {
49+ let Ok ( pgm) =
50+ crate :: connection:: program:: Program :: try_from( pgm. pgm. unwrap( ) ) else {
51+ yield Err ( Status :: new( Code :: InvalidArgument , "invalid program" ) ) ;
52+ break
53+ } ;
54+ let conn = conn. clone( ) ;
55+ let auth = auth. clone( ) ;
56+ let sender = snd. clone( ) ;
57+
58+ let fut = async move {
59+ let builder = StreamResponseBuilder {
60+ request_id,
61+ sender,
62+ current: None ,
63+ current_size: 0 ,
64+ } ;
65+
66+ let ret = conn. execute_program( pgm, auth, builder, None ) . await ;
67+ ( ret, request_id)
68+ } ;
69+
70+ current_request_fut. replace( Box :: pin( fut) ) ;
71+ }
72+ Some ( Request :: Describe ( _) ) => todo!( ) ,
73+ None => {
74+ yield Err ( Status :: new( Code :: InvalidArgument , "invalid request" ) ) ;
75+ break
76+ }
77+ }
78+ }
79+ }
80+ } ,
81+ Some ( res) = recv. recv( ) => {
82+ yield Ok ( res) ;
83+ } ,
84+ ( ret, request_id) = current_request_fut. current( ) , if current_request_fut. is_some( ) => {
85+ if let Err ( e) = ret {
86+ yield Ok ( ExecResp { request_id, response: Some ( Response :: Error ( e. into( ) ) ) } )
87+ }
88+ } ,
89+ else => break ,
90+ }
4691 }
4792 }
4893}
@@ -199,110 +244,3 @@ impl From<ValueRef<'_>> for RowValue {
199244 RowValue { value }
200245 }
201246}
202-
203- enum State {
204- Execute ( Pin < Box < dyn Stream < Item = ExecResp > + Send > > ) ,
205- Idle ,
206- Fused ,
207- }
208-
209- impl < S > Stream for StreamRequestHandler < S >
210- where
211- S : Stream < Item = Result < ExecReq , Status > > ,
212- {
213- type Item = Result < ExecResp , Status > ;
214-
215- fn poll_next ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
216- let this = self . project ( ) ;
217-
218- // we always poll from the request stream. If a new request arrive, we interupt the current
219- // one, and move to the next.
220- if let Poll :: Ready ( maybe_req) = this. request_stream . poll_next ( cx) {
221- match maybe_req {
222- Some ( Err ( e) ) => {
223- * this. state = State :: Fused ;
224- return Poll :: Ready ( Some ( Err ( e) ) )
225- }
226- Some ( Ok ( req) ) => {
227- let request_id = req. request_id ;
228- match req. request {
229- Some ( Request :: Execute ( pgm) ) => {
230- let Ok ( pgm) =
231- crate :: connection:: program:: Program :: try_from ( pgm. pgm . unwrap ( ) ) else {
232- * this. state = State :: Fused ;
233- return Poll :: Ready ( Some ( Err ( Status :: new ( Code :: InvalidArgument , "invalid program" ) ) ) ) ;
234- } ;
235- let conn = this. connection . clone ( ) ;
236- let authenticated = this. authenticated . clone ( ) ;
237-
238- let s = async_stream:: stream! {
239- let ( sender, mut receiver) = mpsc:: channel( 1 ) ;
240- let builder = StreamResponseBuilder {
241- request_id,
242- sender,
243- current: None ,
244- current_size: 0 ,
245- } ;
246- let mut fut = conn. execute_program( pgm, authenticated, builder, None ) ;
247- loop {
248- tokio:: select! {
249- res = & mut fut => {
250- // drain the receiver
251- while let Ok ( msg) = receiver. try_recv( ) {
252- yield msg;
253- }
254-
255- if let Err ( e) = res {
256- yield ExecResp {
257- request_id,
258- response: Some ( exec_resp:: Response :: Error ( e. into( ) ) )
259- }
260- }
261- break
262- }
263- msg = receiver. recv( ) => {
264- if let Some ( msg) = msg {
265- yield msg;
266- }
267- }
268- }
269- }
270- } ;
271- * this. state = State :: Execute ( Box :: pin ( s) ) ;
272- }
273- Some ( Request :: Describe ( _) ) => todo ! ( ) ,
274- None => {
275- * this. state = State :: Fused ;
276- return Poll :: Ready ( Some ( Err ( Status :: new (
277- Code :: InvalidArgument ,
278- "invalid ExecReq: missing request" ,
279- ) ) ) ) ;
280- }
281- }
282- }
283- None => {
284- * this. state = State :: Fused ;
285- return Poll :: Ready ( None )
286- }
287- }
288- }
289-
290- match this. state {
291- State :: Idle => Poll :: Pending ,
292- State :: Fused => Poll :: Ready ( None ) ,
293- State :: Execute ( stream) => {
294- let resp = ready ! ( stream. as_mut( ) . poll_next( cx) ) ;
295- match resp {
296- Some ( resp) => Poll :: Ready ( Some ( Ok ( resp) ) ) ,
297- None => {
298- // finished processing this query. Wake up immediately to prepare for the
299- // next
300- * this. state = State :: Idle ;
301- cx. waker ( ) . wake_by_ref ( ) ;
302- Poll :: Pending
303- }
304- }
305- }
306- }
307- }
308- }
0 commit comments