106 lines
4.0 KiB
Python
106 lines
4.0 KiB
Python
|
|
"""
|
|||
|
|
分析管理API
|
|||
|
|
"""
|
|||
|
|
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
|||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
|
from sqlalchemy import select
|
|||
|
|
|
|||
|
|
from app.database import get_db
|
|||
|
|
from app.models import Report, AnalysisStatus
|
|||
|
|
from app.schemas import MessageResponse, AnalysisRequest
|
|||
|
|
from app.services.ai_analyzer import ai_analyzer
|
|||
|
|
from app.services.pdf_extractor import pdf_extractor
|
|||
|
|
|
|||
|
|
# Every endpoint in this module is mounted under /api/analysis and grouped
# under the "AI分析" (AI analysis) tag in the generated OpenAPI docs.
router = APIRouter(tags=["AI分析"], prefix="/api/analysis")
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def run_analysis_task(report_id: int, force: bool = True) -> None:
    """Background analysis task for a single report.

    Kept as a backward-compatible entry point. It delegates to
    ``run_analysis_with_new_session``, which opens its own session via
    ``AsyncSessionLocal``, extracts the PDF if needed, runs the AI analysis,
    and marks the report as failed if anything goes wrong.

    The previous implementation wrapped everything in
    ``AsyncSession(get_db)``. That passes a function (not an engine) as the
    session bind, so SQLAlchemy raised ``ArgumentError`` before any work ran.

    Args:
        report_id: Primary key of the ``Report`` to analyse.
        force: Forwarded to ``ai_analyzer.analyze_report``. Defaults to
            ``True``, matching the previous hard-coded behaviour.
    """
    await run_analysis_with_new_session(report_id, force)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Strong references to in-flight analysis tasks. The asyncio event loop only
# keeps weak references to tasks, so a task created with create_task() and
# not stored anywhere may be garbage-collected before it finishes.
_running_analysis_tasks = set()


@router.post("/run", response_model=MessageResponse)
async def run_analysis(
    request: AnalysisRequest,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
):
    """Trigger AI analysis for a report.

    Validates that the report exists and has been downloaded, sets its
    status to ``AnalysisStatus.PENDING``, commits, and then schedules
    ``run_analysis_with_new_session`` as a fire-and-forget asyncio task.
    The response is returned without waiting for the analysis to finish.

    ``background_tasks`` is currently unused. It is kept so the endpoint
    signature stays the same.

    Raises:
        HTTPException: 404 if the report does not exist, 400 if it has not
            been downloaded yet, 500 for any other error while starting.
    """
    import asyncio
    import traceback

    try:
        stmt = select(Report).where(Report.id == request.report_id)
        result = await db.execute(stmt)
        report = result.scalar_one_or_none()

        if not report:
            raise HTTPException(status_code=404, detail="报告不存在")

        if not report.is_downloaded:
            raise HTTPException(status_code=400, detail="报告尚未下载,无法分析")

        # Persist PENDING before scheduling, so the background task (which
        # uses its own session) observes the committed state.
        report.analysis_status = AnalysisStatus.PENDING.value
        await db.commit()

        task = asyncio.create_task(
            run_analysis_with_new_session(request.report_id, request.force)
        )
        # Keep the task alive until it completes, then drop the reference.
        _running_analysis_tasks.add(task)
        task.add_done_callback(_running_analysis_tasks.discard)

        return MessageResponse(success=True, message="已开始分析任务")

    except HTTPException:
        raise
    except Exception as e:
        error_msg = f"启动分析失败: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)  # log the full traceback to the console
        # The message is sent to the frontend; the original error is chained.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def run_analysis_with_new_session(report_id: int, force: bool):
    """Extract (if needed) and AI-analyse one report in a dedicated session.

    Designed to run as a fire-and-forget background task. It opens its own
    session from ``AsyncSessionLocal`` rather than reusing a request-scoped
    one. Any ``Exception`` is caught and printed, and the report's
    ``analysis_status`` is set to ``AnalysisStatus.FAILED`` so the frontend
    does not wait on it forever.

    Args:
        report_id: Primary key of the ``Report`` to analyse.
        force: Forwarded to ``ai_analyzer.analyze_report``.
    """
    import traceback

    from sqlalchemy import update

    from app.database import AsyncSessionLocal

    # ``async with`` closes the session on exit; no explicit close() needed.
    async with AsyncSessionLocal() as session:
        try:
            stmt = select(Report).where(Report.id == report_id)
            result = await session.execute(stmt)
            report = result.scalar_one_or_none()

            if report is not None:
                # 1. Make sure the PDF text has been extracted.
                if not report.is_extracted:
                    await pdf_extractor.extract_and_save(session, report)

                # 2. Run the AI analysis.
                await ai_analyzer.analyze_report(session, report, force=force)

        except Exception as e:
            error_trace = traceback.format_exc()
            print(f"后台分析任务严重异常 (report_id={report_id}): {e}\n{error_trace}")

            # Mark the report as failed so the frontend is not stuck forever.
            try:
                # The failed operation may have left the transaction aborted;
                # committing without a rollback would raise PendingRollbackError.
                await session.rollback()
                # Target the row by primary key instead of mutating the
                # (now expired, possibly unloaded) ORM instance.
                await session.execute(
                    update(Report)
                    .where(Report.id == report_id)
                    .values({Report.analysis_status: AnalysisStatus.FAILED.value})
                    .execution_options(synchronize_session=False)
                )
                await session.commit()
            except Exception as commit_err:
                print(f"标记任务失败时发生二次错误: {commit_err}")
|