91 lines
3.0 KiB
Python
91 lines
3.0 KiB
Python
"""
|
|
报告管理API
|
|
"""
|
|
from typing import List, Optional
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, desc
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
from app.database import get_db
|
|
from app.models import Report, ExtractedContent, AnalysisResult
|
|
from app.schemas import ReportResponse, ReportDetail, MessageResponse
|
|
|
|
# Router for the report-management endpoints defined in this module; every
# route is mounted under the /api/reports prefix. The OpenAPI tag "报告管理"
# means "Report management".
router = APIRouter(
    tags=["报告管理"],
    prefix="/api/reports",
)
|
|
|
|
|
|
@router.get("", response_model=List[ReportResponse])
async def list_reports(
    company_id: Optional[int] = None,
    report_type: Optional[str] = None,
    year: Optional[int] = None,
    is_analyzed: Optional[bool] = None,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1),
    db: AsyncSession = Depends(get_db),
):
    """List reports, newest first, with optional filters and pagination.

    Args:
        company_id: Only return reports whose ``company_id`` equals this value.
        report_type: Only return reports of this ``report_type``; an empty
            string is treated as "no filter".
        year: Only return reports whose ``report_year`` equals this value.
        is_analyzed: When given, filter on the ``is_analyzed`` flag.
        page: 1-based page number (rejected with 422 if < 1, instead of
            producing a negative SQL OFFSET).
        page_size: Number of reports per page (rejected with 422 if < 1).
        db: Async database session injected via ``get_db``.

    Returns:
        A list of ``ReportResponse`` objects, each filled in with the owning
        company's ``company_name`` and ``stock_code`` when a company is linked.
    """
    stmt = select(Report).options(selectinload(Report.company))

    # Compare against None so a legitimate falsy value (e.g. 0) still filters,
    # matching how is_analyzed is handled.
    if company_id is not None:
        stmt = stmt.where(Report.company_id == company_id)
    if report_type:
        stmt = stmt.where(Report.report_type == report_type)
    if year is not None:
        stmt = stmt.where(Report.report_year == year)
    if is_analyzed is not None:
        stmt = stmt.where(Report.is_analyzed == is_analyzed)

    # Tie-break on id: without a unique ordering, rows sharing the same
    # created_at can repeat or be skipped across OFFSET/LIMIT pages.
    stmt = (
        stmt.order_by(desc(Report.created_at), desc(Report.id))
        .offset((page - 1) * page_size)
        .limit(page_size)
    )

    result = await db.execute(stmt)

    response = []
    for report in result.scalars().all():
        report_data = ReportResponse.model_validate(report)
        company = report.company
        # Fill in the related company info; skip it rather than raising
        # AttributeError (HTTP 500 for the whole page) if no company is linked.
        if company is not None:
            report_data.company_name = company.company_name
            report_data.stock_code = company.stock_code
        response.append(report_data)

    return response
|
|
|
|
|
|
@router.get("/{report_id}", response_model=ReportDetail)
async def get_report(report_id: int, db: AsyncSession = Depends(get_db)):
    """Fetch one report together with its extracted contents and analysis results.

    Args:
        report_id: Value matched against ``Report.id``.
        db: Async database session injected via ``get_db``.

    Returns:
        A ``ReportDetail`` built from the report, filled in with the owning
        company's ``company_name`` and ``stock_code`` when a company is linked.

    Raises:
        HTTPException: 404 if no report with ``report_id`` exists.
    """
    stmt = (
        select(Report)
        .where(Report.id == report_id)
        .options(
            # Eager-load every relationship the response needs; implicit lazy
            # loading is not supported when using an AsyncSession.
            selectinload(Report.company),
            selectinload(Report.extracted_contents),
            selectinload(Report.analysis_results),
        )
    )
    result = await db.execute(stmt)
    report = result.scalar_one_or_none()

    if report is None:
        raise HTTPException(status_code=404, detail="报告不存在")

    report_data = ReportDetail.model_validate(report)
    company = report.company
    # Skip the company fields instead of raising AttributeError (HTTP 500)
    # if the report has no linked company.
    if company is not None:
        report_data.company_name = company.company_name
        report_data.stock_code = company.stock_code

    return report_data
|
|
|
|
|
|
@router.delete("/{report_id}", response_model=MessageResponse)
async def delete_report(report_id: int, db: AsyncSession = Depends(get_db)):
    """Delete the report whose ``id`` equals ``report_id`` and commit.

    Args:
        report_id: Value matched against ``Report.id``.
        db: Async database session injected via ``get_db``.

    Returns:
        A ``MessageResponse`` with ``success=True`` once the commit completes.

    Raises:
        HTTPException: 404 if no report with ``report_id`` exists.
    """
    lookup = await db.execute(select(Report).where(Report.id == report_id))
    target = lookup.scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="报告不存在")

    # NOTE(review): whether related extracted contents / analysis results are
    # removed too depends on the ORM relationship cascade settings — confirm.
    await db.delete(target)
    await db.commit()

    return MessageResponse(success=True, message="删除成功")
|