Skip to content

Commit e6d8ea2

Browse files
committed
Add invalid request handling callback for websocket server
1 parent bef231d commit e6d8ea2

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

src/handshake/server.rs

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,14 @@ pub trait Callback: Sized {
161161
request: &Request,
162162
response: Response,
163163
) -> StdResult<Response, ErrorResponse>;
164+
165+
/// Called whenever the server read an inavlid request,
166+
/// e.g. the connection upgrade header is missing from the request,
167+
/// then server can call this function to form a valid response
168+
/// instead of just drop the connection
169+
fn on_invalid_request(self, _request: &Request, error: Error) -> Result<ErrorResponse> {
170+
Err(error)
171+
}
164172
}
165173

166174
impl<F> Callback for F
@@ -240,11 +248,21 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
240248
return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
241249
}
242250

243-
let response = create_response(&result)?;
244-
let callback_result = if let Some(callback) = self.callback.take() {
245-
callback.on_request(&result, response)
246-
} else {
247-
Ok(response)
251+
let callback_result = match create_response(&result) {
252+
Ok(response) => {
253+
if let Some(callback) = self.callback.take() {
254+
callback.on_request(&result, response)
255+
} else {
256+
Ok(response)
257+
}
258+
}
259+
Err(error) => {
260+
if let Some(callback) = self.callback.take() {
261+
Err(callback.on_invalid_request(&result, error)?)
262+
} else {
263+
return Err(error);
264+
}
265+
}
248266
};
249267

250268
match callback_result {

0 commit comments

Comments
 (0)