"""
|
||
定时任务服务
|
||
"""
import asyncio
from datetime import datetime, timezone

from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from sqlalchemy import select
from sqlalchemy.orm import selectinload

from app.config import settings
from app.database import AsyncSessionLocal
from app.models import Company, Report, TaskLog
from app.services.ai_analyzer import ai_analyzer
from app.services.cninfo_crawler import cninfo_service
from app.services.pdf_extractor import pdf_extractor
from app.utils.logger import logger


class SchedulerService:
    """Scheduled-task service.

    Wraps an APScheduler ``AsyncIOScheduler`` that periodically syncs
    company reports. Each run has two phases: phase 1 (crawl new report
    listings) executes inside the scheduled job; phase 2 (PDF extraction
    + AI analysis) is spawned as a background asyncio task so the job —
    and any API call that triggers ``run_once()`` — returns quickly.
    """

    def __init__(self):
        self.scheduler = AsyncIOScheduler()
        self.is_running = False    # mirrors scheduler state for get_status()
        self.last_run_time = None  # UTC datetime of the most recent sync run
        self.last_run_status = None  # "success" / "failed" after a run
        # Strong references to in-flight background analysis tasks.
        # The event loop keeps only weak references to tasks, so a task
        # created with asyncio.create_task() and not stored anywhere may
        # be garbage-collected before it finishes.
        self._bg_tasks: set = set()

    def start(self):
        """Start the scheduler and register the periodic sync job.

        Safe to call repeatedly: does nothing if the scheduler is already
        running, and ``replace_existing=True`` avoids duplicate jobs.
        """
        if not self.scheduler.running:
            # Register the periodic check job.
            self.scheduler.add_job(
                self.check_and_sync_reports,
                IntervalTrigger(hours=settings.SCHEDULER_INTERVAL_HOURS),
                id="sync_reports",
                name="同步公司报告",
                replace_existing=True,
            )

            self.scheduler.start()
            self.is_running = True
            logger.info(f"定时任务调度器已启动,间隔: {settings.SCHEDULER_INTERVAL_HOURS}小时")

    def stop(self):
        """Shut the scheduler down if it is currently running."""
        if self.scheduler.running:
            self.scheduler.shutdown()
            self.is_running = False
            logger.info("定时任务调度器已停止")

    def get_status(self) -> dict:
        """Return a snapshot of scheduler state for status endpoints.

        Returns:
            dict with ``is_running``, ``next_run_time``, ``interval_hours``,
            ``last_run_time`` and ``last_run_status``.
        """
        next_run = None
        job = self.scheduler.get_job("sync_reports")
        if job:
            next_run = job.next_run_time

        return {
            "is_running": self.is_running,
            "next_run_time": next_run,
            "interval_hours": settings.SCHEDULER_INTERVAL_HOURS,
            "last_run_time": self.last_run_time,
            "last_run_status": self.last_run_status,
        }

    async def check_and_sync_reports(self):
        """Check and sync reports for all active companies (two phases).

        Phase 1 crawls new reports for every active company; one company
        failing does not abort the run. Phase 2 (extract + analyze) is
        queued as a background task so this coroutine does not block on it.
        A TaskLog row records the outcome of the run.
        """
        logger.info("开始执行定时同步任务")
        self.last_run_time = datetime.now(timezone.utc)

        async with AsyncSessionLocal() as db:
            # Record task start in the task log.
            task_log = TaskLog(
                task_type="crawl",
                task_name="定时同步报告",
                status="started",
                message="开始同步所有公司报告",
            )
            db.add(task_log)
            await db.commit()

            try:
                # ========== Phase 1: sync reports for every active company ==========
                logger.info("========== 阶段1: 同步报告 ==========")
                stmt = select(Company).where(Company.is_active.is_(True))
                result = await db.execute(stmt)
                companies = result.scalars().all()

                total_new = 0
                for company in companies:
                    try:
                        new_count = await cninfo_service.sync_company_reports(db, company)
                        total_new += new_count
                    except Exception as e:
                        # A single failing company must not abort the whole run.
                        logger.error(f"同步 {company.stock_code} 失败: {e}")
                        continue

                logger.info(f"阶段1完成: 共新增 {total_new} 份报告")

                # ========== Phase 2: extract & analyze new reports in the background ==========
                stmt = select(Report).where(
                    Report.is_downloaded.is_(True),
                    Report.is_analyzed.is_(False),
                ).options(selectinload(Report.company))
                result = await db.execute(stmt)
                pending_reports = result.scalars().all()

                pending_count = len(pending_reports)
                logger.info(f"待处理报告: {pending_count} 份,将在后台分析")

                # Mark phase 1 complete; phase 2 continues in the background.
                task_log.status = "completed"
                task_log.message = f"同步完成: 新增 {total_new} 份报告,{pending_count} 份待分析"
                task_log.completed_at = datetime.now(timezone.utc)
                self.last_run_status = "success"
                await db.commit()

                # Launch the background analysis task without awaiting it.
                # Keep a strong reference (the loop holds tasks only weakly,
                # so an unreferenced task can be garbage-collected mid-run)
                # and drop it automatically once the task finishes.
                if pending_reports:
                    task = asyncio.create_task(self._background_analyze(
                        [r.id for r in pending_reports]
                    ))
                    self._bg_tasks.add(task)
                    task.add_done_callback(self._bg_tasks.discard)

            except Exception as e:
                logger.error(f"定时任务执行失败: {e}")
                # Roll back first: after a DB error the session may hold a
                # failed transaction, and committing on it would raise again.
                await db.rollback()
                task_log.status = "failed"
                task_log.error = str(e)
                task_log.completed_at = datetime.now(timezone.utc)
                self.last_run_status = "failed"
                await db.commit()

        logger.info("定时同步任务执行完成")

    async def _background_analyze(self, report_ids: list):
        """Extract and analyze the given reports in the background.

        Each report gets its own short-lived session so that one failure
        cannot poison the session used for the next report; failures are
        logged and skipped.

        Args:
            report_ids: primary keys of the reports to process.
        """
        logger.info(f"========== 后台分析任务启动: {len(report_ids)} 份报告 ==========")

        analyzed_count = 0
        for report_id in report_ids:
            try:
                async with AsyncSessionLocal() as db:
                    stmt = select(Report).where(Report.id == report_id).options(
                        selectinload(Report.company)
                    )
                    result = await db.execute(stmt)
                    report = result.scalar_one_or_none()

                    # Re-check is_analyzed: another worker may have handled it.
                    if report and not report.is_analyzed:
                        logger.info(f"后台分析: {report.title}")

                        # Extract the PDF content first if not done yet.
                        if not report.is_extracted:
                            await pdf_extractor.extract_and_save(db, report)

                        # Then run the AI analysis.
                        success = await ai_analyzer.analyze_report(db, report)
                        if success:
                            analyzed_count += 1

            except Exception as e:
                # Skip the failing report and keep processing the rest.
                logger.error(f"后台分析报告 {report_id} 失败: {e}")
                continue

        logger.info(f"========== 后台分析任务完成: 成功 {analyzed_count}/{len(report_ids)} 份 ==========")

    async def run_once(self):
        """Run one full sync cycle immediately (manual trigger)."""
        await self.check_and_sync_reports()


# Module-level singleton shared by importers of this module.
scheduler_service = SchedulerService()