106 lines
4.0 KiB
Python
106 lines
4.0 KiB
Python
"""
|
||
分析管理API
|
||
"""
|
||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select
|
||
|
||
from app.database import get_db
|
||
from app.models import Report, AnalysisStatus
|
||
from app.schemas import MessageResponse, AnalysisRequest
|
||
from app.services.ai_analyzer import ai_analyzer
|
||
from app.services.pdf_extractor import pdf_extractor
|
||
|
||
# Router for the AI-analysis endpoints; every route below is mounted under
# the "/api/analysis" prefix and grouped under the "AI分析" tag.
router = APIRouter(
    tags=["AI分析"],
    prefix="/api/analysis",
)
|
||
|
||
|
||
async def run_analysis_task(report_id: int, force: bool = True):
    """Background analysis task (legacy entry point).

    Ensures the report's PDF has been extracted and then runs the AI analysis,
    using a dedicated database session rather than a request-scoped one.

    The previous implementation built ``AsyncSession(get_db)``, which passes a
    function where SQLAlchemy expects an async engine/connection, so the task
    failed before doing any work. It now delegates to
    ``run_analysis_with_new_session``. That function opens its own session and
    marks the report as failed if anything goes wrong.

    Args:
        report_id: Primary key of the ``Report`` to analyze.
        force: Forwarded to ``ai_analyzer.analyze_report``. Defaults to
            ``True``, matching the value this task previously hard-coded.
    """
    await run_analysis_with_new_session(report_id, force)
|
||
|
||
|
||
# Strong references to in-flight analysis tasks. The event loop keeps only weak
# references to tasks, so a task created with asyncio.create_task and not
# stored anywhere can be garbage-collected before it finishes.
_running_analysis_tasks = set()


@router.post("/run", response_model=MessageResponse)
async def run_analysis(
    request: AnalysisRequest,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
):
    """Trigger AI analysis of a downloaded report.

    Marks the report as ``AnalysisStatus.PENDING`` and commits. It then
    schedules ``run_analysis_with_new_session`` on the running event loop and
    returns immediately, without waiting for the analysis.

    Args:
        request: Provides ``report_id`` and ``force`` (forwarded to the
            background job).
        background_tasks: Injected by FastAPI but not used. The job is
            scheduled with ``asyncio.create_task`` instead. The parameter is
            kept so the endpoint signature stays unchanged.
        db: Request-scoped session from ``get_db``.

    Returns:
        ``MessageResponse(success=True, message="已开始分析任务")``.

    Raises:
        HTTPException: 404 if the report does not exist; 400 if it has not
            been downloaded; 500 (detail = ``str(e)``) for any other error.
    """
    import asyncio
    import logging

    try:
        stmt = select(Report).where(Report.id == request.report_id)
        result = await db.execute(stmt)
        report = result.scalar_one_or_none()

        if not report:
            raise HTTPException(status_code=404, detail="报告不存在")

        if not report.is_downloaded:
            raise HTTPException(status_code=400, detail="报告尚未下载,无法分析")

        # Persist PENDING before the job starts so the frontend sees it at once.
        report.analysis_status = AnalysisStatus.PENDING.value
        await db.commit()

        # Start the background job and hold a reference until it completes.
        task = asyncio.create_task(
            run_analysis_with_new_session(request.report_id, request.force)
        )
        _running_analysis_tasks.add(task)
        task.add_done_callback(_running_analysis_tasks.discard)

        return MessageResponse(success=True, message="已开始分析任务")
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception records the full traceback.
        logging.getLogger(__name__).exception("启动分析失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e  # returned to the frontend
|
||
|
||
|
||
async def run_analysis_with_new_session(report_id: int, force: bool):
    """Extract (if needed) and AI-analyze one report in its own DB session.

    This runs detached from the HTTP request, so it opens a fresh session from
    ``AsyncSessionLocal`` instead of using the request-scoped one.

    Args:
        report_id: Primary key of the ``Report`` to analyze. If no such row
            exists, the function returns without doing anything.
        force: Forwarded to ``ai_analyzer.analyze_report``.

    Any ``Exception`` is logged and not re-raised. The function then attempts
    to set the report's ``analysis_status`` to ``AnalysisStatus.FAILED`` so
    the frontend is not left waiting forever.
    """
    import logging

    from app.database import AsyncSessionLocal

    logger = logging.getLogger(__name__)

    # The async context manager closes the session on exit.
    async with AsyncSessionLocal() as session:
        try:
            stmt = select(Report).where(Report.id == report_id)
            result = await session.execute(stmt)
            report = result.scalar_one_or_none()
            if report is None:
                return

            # 1. Make sure the PDF content has been extracted.
            if not report.is_extracted:
                await pdf_extractor.extract_and_save(session, report)

            # 2. Run the analysis.
            await ai_analyzer.analyze_report(session, report, force=force)
        except Exception as e:
            logger.exception("后台分析任务严重异常 (report_id=%s): %s", report_id, e)

            # Mark the report as failed so the frontend does not stay stuck.
            try:
                # A failed statement leaves the transaction unusable, and
                # commit() would then raise. Roll back first.
                await session.rollback()
                # Re-load the row. Rollback expires loaded instances, and the
                # initial SELECT may itself have been what failed.
                failed_report = await session.get(Report, report_id)
                if failed_report is not None:
                    failed_report.analysis_status = AnalysisStatus.FAILED.value
                    await session.commit()
            except Exception as commit_err:
                logger.error("标记任务失败时发生二次错误: %s", commit_err)
|