|
3 | 3 | import traceback |
4 | 4 |
|
5 | 5 | import numpy as np |
| 6 | +import orjson |
6 | 7 | import pandas as pd |
7 | 8 | from fastapi import APIRouter, HTTPException |
8 | 9 | from fastapi.responses import StreamingResponse |
@@ -147,56 +148,52 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que |
147 | 148 | llm_service.run_task_async() |
148 | 149 | except Exception as e: |
149 | 150 | traceback.print_exc() |
150 | | - raise HTTPException( |
151 | | - status_code=500, |
152 | | - detail=str(e) |
153 | | - ) |
| 151 | + |
| 152 | + def _err(_e: Exception): |
| 153 | + yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n' |
| 154 | + |
| 155 | + return StreamingResponse(_err(e), media_type="text/event-stream") |
154 | 156 |
|
155 | 157 | return StreamingResponse(llm_service.await_result(), media_type="text/event-stream") |
156 | 158 |
|
157 | 159 |
|
@router.post("/record/{chat_record_id}/{action_type}")
async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, chat_record_id: int, action_type: str,
                              current_assistant: CurrentAssistant):
    """Run a follow-up 'analysis' or 'predict' LLM task on an existing chat record.

    Loads the chat record identified by ``chat_record_id``, validates that it has a
    generated chart, then kicks off the requested async LLM task and streams its
    result back as server-sent events.

    Any failure (unknown action type, missing record, record without a chart, or an
    error while creating/starting the LLM task) is NOT raised as an HTTP error:
    it is reported to the client as a single SSE ``error`` event on a 200 response,
    matching the streaming contract of the success path.
    """
    try:
        # Only these two follow-up actions are supported.
        if action_type not in ('analysis', 'predict'):
            raise ValueError(f"Type {action_type} Not Found")

        record: ChatRecord | None = None

        # Fetch only the columns needed to rebuild a lightweight ChatRecord;
        # and_() around the single condition was redundant and is dropped.
        stmt = select(ChatRecord.id, ChatRecord.question, ChatRecord.chat_id, ChatRecord.datasource,
                      ChatRecord.engine_type,
                      ChatRecord.ai_modal_id, ChatRecord.create_by, ChatRecord.chart,
                      ChatRecord.data).where(ChatRecord.id == chat_record_id)
        result = session.execute(stmt)
        for r in result:
            record = ChatRecord(id=r.id, question=r.question, chat_id=r.chat_id, datasource=r.datasource,
                                engine_type=r.engine_type, ai_modal_id=r.ai_modal_id, create_by=r.create_by,
                                chart=r.chart,
                                data=r.data)
            # id is an equality match on the key, so at most one row is expected;
            # stop after the first instead of silently keeping the last.
            break

        if not record:
            raise ValueError(f"Chat record with id {chat_record_id} not found")

        # A chart must already have been generated before it can be analyzed/predicted.
        if not record.chart:
            raise ValueError(
                f"Chat record with id {chat_record_id} has not generated chart, do not support to analyze it")

        request_question = ChatQuestion(chat_id=record.chat_id, question=record.question)

        llm_service = await LLMService.create(current_user, request_question, current_assistant)
        llm_service.run_analysis_or_predict_task_async(action_type, record)
    except Exception as e:
        traceback.print_exc()

        def _err(_e: Exception):
            # Single SSE frame carrying the error text; client distinguishes it by type.
            yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n'

        return StreamingResponse(_err(e), media_type="text/event-stream")

    return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
202 | 199 |
|
|
0 commit comments