369 lines
15 KiB
Python
369 lines
15 KiB
Python
"""
|
||
AI分析服务 - 多AI分层分析
|
||
"""
|
||
import json
|
||
import asyncio
|
||
from typing import List, Dict, Optional
|
||
import httpx
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select
|
||
|
||
from app.config import settings
|
||
from app.utils.logger import logger
|
||
from app.models import Report, ExtractedContent, AnalysisResult, AnalysisStatus
|
||
|
||
|
||
class AIAnalyzer:
    """AI analysis service.

    Sends extracted annual-report sections to a chat-completions style HTTP
    endpoint, stores one ``AnalysisResult`` row per analysed section and then
    asks the model for a final summary of the successful section analyses.
    """

    # HTTP statuses below 500 that are transient and therefore worth retrying
    # (request timeout, rate limiting).
    _RETRYABLE_CLIENT_STATUSES = (408, 429)

    # Ordered (content_type, keyword markers) pairs; the first match wins, so
    # the order encodes the priority financial > operation > market.
    _CONTENT_TYPE_KEYWORDS = (
        ("financial", ("资产负债表", "利润表", "现金流量表", "财务数据", "财务指标")),
        ("operation", ("主营业务", "经营情况", "收入", "成本", "毛利率", "费用", "研发")),
        ("market", ("行业", "市场", "竞争", "份额")),
    )

    def __init__(self):
        """Read the AI endpoint settings and create the shared HTTP client."""
        self.api_url = settings.AI_API_URL
        self.api_key = settings.AI_API_KEY
        self.model = settings.AI_MODEL
        self.parallel_count = settings.AI_PARALLEL_COUNT
        # 10-minute timeout for model calls.
        self.timeout = httpx.Timeout(600.0)
        # A single client is reused for every request.
        self.client = httpx.AsyncClient(timeout=self.timeout)

    async def aclose(self):
        """Close the shared HTTP client; safe to call more than once."""
        client = getattr(self, "client", None)
        if client is not None and not getattr(client, "is_closed", True):
            await client.aclose()

    def __del__(self):
        """Best-effort cleanup when the analyzer is garbage collected.

        A finalizer cannot be a coroutine: Python calls ``__del__`` without
        awaiting it, so an ``async def __del__`` never runs its body. Prefer
        calling ``aclose()`` explicitly; this only schedules the close when an
        event loop happens to be running.
        """
        client = getattr(self, "client", None)
        if client is None or getattr(client, "is_closed", True):
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop (e.g. interpreter shutdown): nothing to await on.
            return
        loop.create_task(client.aclose())

    @classmethod
    def _classify_keyword(cls, keyword: str) -> str:
        """Map a section keyword to its content type ("strategy" by default)."""
        for content_type, markers in cls._CONTENT_TYPE_KEYWORDS:
            if any(marker in keyword for marker in markers):
                return content_type
        return "strategy"

    async def call_ai(self, prompt: str, system_prompt: Optional[str] = None) -> Dict:
        """Call the AI endpoint with retries, reusing the shared client.

        Args:
            prompt: User message sent to the model.
            system_prompt: Optional system message placed before the prompt.

        Returns:
            ``{"success": True, "content": str, "tokens": int}`` on success,
            otherwise ``{"success": False, "error": str}``. This method does
            not raise for request failures; it reports them in the result.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": 0.3,
            "max_tokens": 4096,
        }

        # Seconds to wait before retry 1, 2 and 3.
        retry_delays = (2, 5, 10)
        max_retries = len(retry_delays)
        failure = ""

        for attempt in range(max_retries + 1):
            try:
                response = await self.client.post(
                    self.api_url,
                    headers=headers,
                    json=payload,
                )
                status = response.status_code

                if status == 200:
                    data = response.json()
                    content = data["choices"][0]["message"]["content"]
                    # "usage" may be missing or null; neither should turn a
                    # good answer into a retry.
                    usage = data.get("usage") or {}
                    tokens = usage.get("total_tokens", 0) or 0
                    return {"success": True, "content": content, "tokens": tokens}

                failure = f"AI调用失败: {status} {response.text}"
                logger.error(failure)
                retryable = status >= 500 or status in self._RETRYABLE_CLIENT_STATUSES
                if not retryable:
                    # Other 4xx answers will not change on retry.
                    return {"success": False, "error": response.text}
            except Exception as e:
                # Boundary catch: transport errors and malformed payloads are
                # reported through the result dict instead of propagating.
                failure = str(e)

            if attempt == max_retries:
                break
            delay = retry_delays[attempt]
            logger.warning(f"AI调用异常: {failure},{delay}秒后重试...")
            await asyncio.sleep(delay)

        logger.error(f"AI调用异常: {failure}")
        return {"success": False, "error": failure}

    async def analyze_section(self, section: "ExtractedContent", company_name: str) -> Dict:
        """Analyse one extracted section with a prompt chosen by its keyword.

        Args:
            section: Extracted section; ``section_name``, ``section_keyword``
                and ``content`` are read.
            company_name: Company name interpolated into the prompt.

        Returns:
            Dict with ``section_name``, ``keyword``, ``analysis``, ``tokens``,
            ``success`` and ``content_type`` (one of "financial", "operation",
            "market", "strategy").
        """
        # A missing keyword or body is treated as empty instead of crashing.
        keyword = section.section_keyword or ""
        body = section.content or ""
        content_type = self._classify_keyword(keyword)

        system_prompt = """你是一位服务于亿级电商眼镜商家(全平台、多品牌)的战略顾问。
客户现状:主营太阳镜/运动/光学镜,年收1亿。正在进行渠道扩张(分销)和数字化转型。
你的任务:基于上市公司年报原文,为客户提供实战参考和战略情报。

【核心准则】:
1. **身份定位**:你是顾问,要提供不仅是数据,更是业务视角的洞察。
2. **证据至上**:你的所有结论必须直接引自或推导自原文事实。
3. **拒绝脑补**:严禁任何形式的脱离原文的“反向推导”、猜测或幻觉。
4. **诚实声明**:如果关键信息缺失,请直接指出,严禁臆测。"""

        if content_type == "financial":
            analysis_prompt = f"""作为顾问,请研读{company_name}的财务数据片段,结合客户(年收1亿且在转型)的需求进行分析:

---
{body[:10000]}
---

任务:
1. **真实IT底色**:从报表中找出确凿的系统/软件/研发投入。
2. **经营稳定性**:通过存货和利润数据,分析其作为龙头的真实盈利质量。
3. **投入产出比**:分析其费用结构是否合理。

注意:仅分析原文,绝不编造。"""

        elif content_type == "operation":
            analysis_prompt = f"""作为顾问,请从{company_name}的经营描述中,拆解其对客户有价值的实战战术:

---
{body[:10000]}
---

任务:
1. **真实数字化**:原文中是否明确提到了使用的系统?解决了什么业务问题?
2. **渠道体系**:其分销/加盟/直营模式的真实描述是什么?
3. **品类逻辑**:他们在什么品类上投入最重?

注意:严禁强行关联数字化。"""

        elif content_type == "market":
            analysis_prompt = f"""作为顾问,请分析{company_name}眼中的市场局局势,为客户的战略布局提供方向:

---
{body[:8000]}
---

任务:
1. **核心壁垒**:大厂公开承认的竞争优势是什么?
2. **行业机会**:原文中他们重点看好哪些细分增长点?

注意:仅提炼策略。"""

        else:
            analysis_prompt = f"""作为顾问,分析{company_name}关于「{keyword}」的真实战略图谋:

---
{body[:8000]}
---

任务:
1. **明确投向**:未来大厂明确要投入资源解决的问题是什么?
2. **战略排序**:他们的重心是在“求快”还是“求稳”?

注意:必须基于原文描述。"""

        result = await self.call_ai(analysis_prompt, system_prompt)
        return {
            "section_name": section.section_name,
            "keyword": section.section_keyword,
            "analysis": result.get("content", ""),
            "tokens": result.get("tokens", 0),
            "success": result.get("success", False),
            "content_type": content_type,
        }

    async def summarize_analyses(self, analyses: List[Dict], company_name: str, report_title: str) -> Dict:
        """Merge the successful section analyses into one final briefing.

        Args:
            analyses: Dicts as returned by ``analyze_section``; entries whose
                ``success`` is falsy are ignored.
            company_name: Company name interpolated into the prompt.
            report_title: Report title interpolated into the prompt.

        Returns:
            Dict with ``summary``, ``tokens`` and ``success``.
        """
        system_prompt = """你是一位专注于零售品牌增长的战略顾问。
客户是年收1亿的电商眼镜品牌,正处于“拓展分销”、“拓新品类”和“数字化转型”的关键期。
请基于上市公司年报,为客户提供一份具备战略高度的扩张与升级指南。"""

        def collect(*content_types):
            # Join the successful analyses belonging to the given content types.
            return "\n".join(
                item["analysis"]
                for item in analyses
                if item.get("success") and item.get("content_type") in content_types
            )

        financial_text = collect("financial")
        operation_text = collect("operation")
        strategy_text = collect("market", "strategy")
        analyses_text = (
            f"【财务情报】\n{financial_text}\n\n"
            f"【运营实战】\n{operation_text}\n\n"
            f"【战略情报】\n{strategy_text}"
        )

        prompt = f"""以下是从{company_name}《{report_title}》提炼出的关键情报:

{analyses_text}

请合成一份**《高密度实战内参》**。

【内参结构】:
## 📊 关键指标对标局
## 🛠 数字化实战方案
## 🌍 全渠道与分销战法
## 🎯 新机会雷达
## 💡 顾问最终建议(3条)"""

        result = await self.call_ai(prompt, system_prompt)
        return {
            "summary": result.get("content", ""),
            "tokens": result.get("tokens", 0),
            "success": result.get("success", False),
        }

    async def analyze_report(self, db: "AsyncSession", report: "Report", force: bool = False) -> bool:
        """Analyse every extracted section of a report, then summarise.

        Args:
            db: Session used for report status updates and the final summary.
            report: Report to analyse.
            force: Re-run the analysis even if the report is already analysed.

        Returns:
            True when the report was already analysed or a final summary was
            stored; False when extraction produced nothing, no section could
            be analysed, or the summary call failed.
        """
        if report.is_analyzed and not force:
            logger.info(f"报告已分析: {report.title}")
            return True

        # Snapshot plain values before the first commit so later code (and the
        # worker tasks) never has to re-read ORM attributes.
        # NOTE(review): only matters if the session expires objects on commit;
        # confirm against the session factory configuration.
        report_id = report.id
        report_title = report.title
        company_id = report.company_id

        report.analysis_status = AnalysisStatus.ANALYZING.value
        await db.commit()

        # Load the extracted sections for this report.
        stmt = select(ExtractedContent).where(ExtractedContent.report_id == report_id)
        contents = (await db.execute(stmt)).scalars().all()

        if not contents:
            logger.info(f"报告无提取内容,尝试自动提取: {report_title}")
            # Imported lazily, as in the original code.
            from app.services.pdf_extractor import pdf_extractor
            try:
                contents = await pdf_extractor.extract_and_save(db, report)
            except Exception as e:
                logger.error(f"自动提取异常: {e}")
                contents = None
            if not contents:
                report.analysis_status = AnalysisStatus.FAILED.value
                await db.commit()
                return False

        from app.models import Company
        company_stmt = select(Company).where(Company.id == company_id)
        company = (await db.execute(company_stmt)).scalar_one_or_none()
        company_name = "未知公司"
        if company:
            # Fall back to the placeholder when both name fields are empty.
            company_name = company.short_name or company.company_name or "未知公司"

        total = len(contents)
        logger.info(f"开始分析报告: {report_title}, 共 {total} 个章节")

        # Remove results from any previous run; failure here is logged only.
        try:
            from sqlalchemy import delete
            await db.execute(delete(AnalysisResult).where(AnalysisResult.report_id == report_id))
            await db.commit()
        except Exception as e:
            logger.error(f"清除旧结果失败: {e}")

        from app.database import AsyncSessionLocal

        # Worker-queue pattern: a fixed number of workers drain the queue.
        queue = asyncio.Queue()
        for item in enumerate(contents):
            queue.put_nowait(item)

        indexed_results = []
        # Serialises the incremental saves issued by the workers.
        db_lock = asyncio.Lock()

        async def worker():
            while True:
                try:
                    # get_nowait so an idle worker exits instead of waiting.
                    idx, content = queue.get_nowait()
                except asyncio.QueueEmpty:
                    return

                try:
                    # Pause before each section (kept from the original code).
                    await asyncio.sleep(0.5)
                    logger.info(f"Worker处理章节 [{idx+1}/{total}]: {content.section_name}")

                    res = await self.analyze_section(content, company_name)
                    if not res.get("success"):
                        continue

                    async with db_lock:
                        try:
                            # Each save uses its own short-lived session.
                            async with AsyncSessionLocal() as session:
                                session.add(AnalysisResult(
                                    report_id=report_id,
                                    analysis_type="section",
                                    section_name=res["section_name"],
                                    ai_model=self.model,
                                    summary=res["analysis"],
                                    token_count=res["tokens"],
                                    is_final=False,
                                ))
                                await session.commit()
                            indexed_results.append((idx, res))
                        except Exception as e:
                            logger.error(f"章节增量保存失败: {e}")
                except Exception as e:
                    logger.error(f"Worker异常: {e}")
                finally:
                    queue.task_done()

        # Always run at least one worker, and never more than there are items.
        worker_count = max(1, min(self.parallel_count or 1, total))
        await asyncio.gather(*(worker() for _ in range(worker_count)))

        report.analysis_status = AnalysisStatus.SUMMARIZING.value
        await db.commit()

        # Restore section order so the summary input does not depend on which
        # worker happened to finish first.
        valid_results = [res for _, res in sorted(indexed_results, key=lambda pair: pair[0])]

        summary = None
        if valid_results:
            summary = await self.summarize_analyses(valid_results, company_name, report_title)

        if not summary or not summary.get("success"):
            # Nothing usable was produced: do not report the run as completed.
            logger.error(f"报告分析失败(无有效章节结果或汇总失败): {report_title}")
            report.is_analyzed = False
            report.analysis_status = AnalysisStatus.FAILED.value
            await db.commit()
            return False

        db.add(AnalysisResult(
            report_id=report_id,
            analysis_type="summary",
            section_name="综合分析",
            ai_model=self.model,
            summary=summary["summary"],
            token_count=summary.get("tokens", 0),
            is_final=True,
        ))

        report.is_analyzed = True
        report.analysis_status = AnalysisStatus.COMPLETED.value
        await db.commit()
        return True

    async def analyze_with_quality_check(
        self,
        db: "AsyncSession",
        report: "Report",
        enable_quality_check: bool = False
    ) -> bool:
        """Analyse a report and optionally run the quality evaluation.

        Args:
            db: Session passed to ``analyze_report`` and the evaluator.
            report: Report to analyse.
            enable_quality_check: When True and the analysis succeeded, the
                stored summary is handed to the quality evaluator.

        Returns:
            The boolean outcome of ``analyze_report``; the quality evaluation
            only logs its verdict and does not change the return value.
        """
        succeeded = await self.analyze_report(db, report)

        if succeeded and enable_quality_check:
            try:
                from app.services.quality_evaluator import quality_evaluator

                # Fetch the summary that was just stored. ``select`` is the
                # module-level import; re-importing it inside this function
                # would make it a local name and break this call.
                stmt = select(AnalysisResult).where(
                    AnalysisResult.report_id == report.id,
                    AnalysisResult.analysis_type == "summary"
                )
                query_result = await db.execute(stmt)
                summary_result = query_result.scalar_one_or_none()

                if summary_result:
                    passed, evaluation = await quality_evaluator.evaluate_and_decide(
                        db, summary_result
                    )
                    logger.info(f"质量评估: 通过={passed}, 评分={evaluation.get('avg_score', 0)}")

                    if not passed:
                        logger.warning(f"质量不达标,建议人工复核: {report.title}")

            except ImportError:
                logger.warning("质量评估模块未加载")

        return succeeded
|
||
|
||
|
||
# Module-level analyzer instance, constructed when this module is imported.
ai_analyzer = AIAnalyzer()
|
||
|