"""
|
|||
|
|
AI质量评估服务 - 多评判师机制
|
|||
|
|
使用多个AI评判师对分析结果进行质量评估
|
|||
|
|
通过共识算法决定是否需要重新生成
|
|||
|
|
"""
|
|||
|
|
import asyncio
import json
from typing import List, Dict, Optional, Tuple
import httpx
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.utils.logger import logger
from app.models import AnalysisResult


class AIJudge:
    """A single AI judge."""

    def __init__(self, judge_id: str, personality: str):
        self.judge_id = judge_id
        self.personality = personality
        self.api_url = settings.AI_API_URL
        self.api_key = settings.AI_API_KEY
        self.model = settings.AI_MODEL
        self.timeout = httpx.Timeout(60.0)

    async def evaluate(self, content: str, criteria: Dict) -> Dict:
        """Evaluate the quality of the given content."""
        system_prompt = f"""You are an analysis quality review expert who is {self.personality}.
Score the analysis content strictly against the evaluation criteria.
Return only the scores in JSON format, with no other text."""

        prompt = f"""Please evaluate the quality of the following AI-generated analysis:

---Analysis content---
{content[:4000]}
---

Evaluation criteria:
1. Completeness (completeness): does it cover all key points (1-10)
2. Accuracy (accuracy): is the analysis accurate and well supported (1-10)
3. Insight (insight): does it provide valuable insight (1-10)
4. Readability (readability): is the structure clear and easy to follow (1-10)
5. Professionalism (professionalism): does it apply professional analysis methods (1-10)

Return the scores strictly in the following JSON format:
{{"completeness": 8, "accuracy": 7, "insight": 6, "readability": 9, "professionalism": 8, "overall": 7.6, "comments": "brief comment"}}"""

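        # Note: the analysis content is truncated to its first 4000 characters in the
        # prompt above, presumably to keep the request within prompt-size limits.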
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt}
            ],
            "temperature": 0.2,
            "max_tokens": 500
        }

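        # The request/response shapes assume an OpenAI-compatible chat-completions
        # endpoint; the parsing below reads data["choices"][0]["message"]["content"]
        # accordingly.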
        try:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                response = await client.post(self.api_url, headers=headers, json=payload)

            if response.status_code == 200:
                data = response.json()
                content = data["choices"][0]["message"]["content"]

                # Parse the JSON scores
                try:
                    # Try to extract the JSON portion of the reply
                    json_match = content
                    if "{" in content:
                        start = content.index("{")
                        end = content.rindex("}") + 1
                        json_match = content[start:end]
                    scores = json.loads(json_match)
                    scores["judge_id"] = self.judge_id
                    return {"success": True, "scores": scores}
                except (json.JSONDecodeError, ValueError):
                    logger.warning(f"Judge {self.judge_id} returned a malformed reply: {content[:100]}")
                    return {"success": False, "error": "Failed to parse JSON"}
            else:
                return {"success": False, "error": f"API error: {response.status_code}"}

        except Exception as e:
            logger.error(f"Judge {self.judge_id} evaluation failed: {e}")
            return {"success": False, "error": str(e)}


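# Illustrative shape of a successful AIJudge.evaluate() result (values are examples
# only): {"success": True, "scores": {"completeness": 8, "accuracy": 7, "insight": 6,
# "readability": 9, "professionalism": 8, "overall": 7.6, "comments": "...",
# "judge_id": "judge_strict"}}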
class QualityEvaluator:
    """Quality evaluator - multi-judge consensus mechanism."""

    def __init__(self):
        # Create three judges with different evaluation styles
        self.judges = [
            AIJudge("judge_strict", "strict and detail-oriented"),
            AIJudge("judge_balanced", "balanced and objective"),
            AIJudge("judge_practical", "focused on practicality and business value")
        ]

        # Quality thresholds
        self.quality_threshold = 6.5  # out of 10; below 6.5 requires regeneration
        self.consensus_threshold = 2  # at least 2 judges must agree to pass

    async def evaluate_analysis(self, analysis_content: str) -> Dict:
        """
        Evaluate analysis quality using multiple judges.
        Returns: {passed: bool, avg_score: float, should_regenerate: bool, details: [...]}
        """
        logger.info("Starting multi-judge quality evaluation...")

        # Call all judges in parallel
        tasks = [judge.evaluate(analysis_content, {}) for judge in self.judges]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Collect valid scores
        valid_scores = []
        for result in results:
            if isinstance(result, Exception):
                logger.error(f"Judge raised an exception: {result}")
                continue
            if result.get("success"):
                valid_scores.append(result["scores"])
                logger.info(f"Judge {result['scores']['judge_id']}: overall {result['scores'].get('overall', 0)}")

        if len(valid_scores) < 2:
            logger.warning("Not enough valid scores; skipping quality evaluation")
            return {
                "passed": True,  # pass by default when there are too few scores
                "avg_score": 0,
                "should_regenerate": False,
                "details": [],
                "reason": "Insufficient judge responses"
            }

        # Average overall score
        avg_overall = sum(s.get("overall", 0) for s in valid_scores) / len(valid_scores)

        # Per-dimension averages
        dimensions = ["completeness", "accuracy", "insight", "readability", "professionalism"]
        avg_dimensions = {}
        for dim in dimensions:
            dim_scores = [s.get(dim, 0) for s in valid_scores]
            avg_dimensions[dim] = sum(dim_scores) / len(dim_scores)

        # Consensus: how many judges rate the result at or above the threshold
        pass_votes = sum(1 for s in valid_scores if s.get("overall", 0) >= self.quality_threshold)

        passed = pass_votes >= self.consensus_threshold
        should_regenerate = not passed

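        # Worked example (illustrative numbers): overall scores of 7.2, 6.0 and 8.1
        # against quality_threshold=6.5 give pass_votes=2, which meets
        # consensus_threshold=2, so passed=True and should_regenerate=False.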
        logger.info(f"Quality evaluation finished: average {avg_overall:.2f}, pass votes {pass_votes}/{len(valid_scores)}, regenerate: {should_regenerate}")

        return {
            "passed": passed,
            "avg_score": round(avg_overall, 2),
            "pass_votes": pass_votes,
            "total_judges": len(valid_scores),
            "should_regenerate": should_regenerate,
            "dimension_scores": avg_dimensions,
            "details": valid_scores,
            "reason": "Meets the quality bar" if passed else "Below the quality bar; regeneration recommended"
        }

    async def evaluate_and_decide(
        self,
        db: AsyncSession,
        analysis_result: AnalysisResult,
        max_retries: int = 2
    ) -> Tuple[bool, Dict]:
        """
        Evaluate an analysis result and decide whether it should be regenerated.

        Args:
            db: database session
            analysis_result: the analysis result object
            max_retries: maximum number of regeneration attempts

        Returns:
            (is_acceptable, evaluation_details)
        """
        evaluation = await self.evaluate_analysis(analysis_result.summary)

        # Persist the evaluation onto the analysis record (extensible field storage)
        # analysis_result.evaluation_score = evaluation["avg_score"]
        # await db.commit()

        return evaluation["passed"], evaluation


# Module-level singleton instance
quality_evaluator = QualityEvaluator()
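

# Usage sketch (illustrative only, not part of the service wiring). It assumes
# AI_API_URL / AI_API_KEY / AI_MODEL are configured in app.config.settings and
# that the text to score is already available as a plain string.
if __name__ == "__main__":
    async def _demo() -> None:
        sample_analysis = "Sample analysis text to be scored by the judges..."
        verdict = await quality_evaluator.evaluate_analysis(sample_analysis)
        print(verdict["passed"], verdict["avg_score"], verdict["reason"])

    asyncio.run(_demo())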