ai_v/services/task_service.py
24024 158ba123b1 ```
feat(task_service): 增强任务服务的错误处理和重试机制

- 添加了 _extract_error_detail 函数,用于从 API 响应中智能提取详细的错误信息,
  支持多种常见的错误字段格式,提高错误诊断准确性

- 集成 requests 异常处理,区分连接超时、连接错误和读取超时等不同类型的网络异常,
  实现更精确的重试策略,避免因响应丢失导致的任务重复提交

- 在图像和视频生成流程中统一使用新的错误提取函数替代原有的简单错误字段获取,
  提升失败任务的错误信息详细程度

- 优化异常处理逻辑,对不同类型的异常采用相应的处理策略,包括安全重试和终止重试
```
2026-03-12 23:25:46 +08:00

446 lines
22 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import uuid
import json
import requests
import io
import time
import base64
import threading
from urllib.parse import quote
from requests.exceptions import ConnectionError, ConnectTimeout
from PIL import Image
from extensions import s3_client, redis_client, db
from models import GenerationRecord, User
from config import Config
from services.logger import system_logger
from utils import get_proxied_url
def _extract_error_detail(data, default="未知错误"):
"""从API响应中提取详细的错误信息"""
if not isinstance(data, dict):
return default
# 按优先级尝试多个常见的错误字段
candidates = [
data.get('fail_reason'),
data.get('error', {}).get('message') if isinstance(data.get('error'), dict) else data.get('error'),
data.get('message'),
data.get('detail'),
data.get('data', {}).get('fail_reason') if isinstance(data.get('data'), dict) else None,
data.get('data', {}).get('message') if isinstance(data.get('data'), dict) else None,
data.get('data', {}).get('error') if isinstance(data.get('data'), dict) else None,
]
for c in candidates:
if c and isinstance(c, str) and c.strip():
return c.strip()
# 如果所有字段都没有,返回响应摘要帮助排查
# 过滤掉大字段,只保留状态类信息
summary_keys = {'status', 'code', 'error_code', 'fail_code', 'type'}
summary = {k: v for k, v in data.items() if k in summary_keys and v}
if summary:
return f"{default} ({json.dumps(summary, ensure_ascii=False)})"
return default
def sync_images_background(app, record_id, raw_urls):
    """Background job: mirror generated images into MinIO and build thumbnails.

    Each URL is downloaded through the proxy with up to 3 attempts and
    exponential backoff.  On success the original is uploaded as PNG and a
    <=400px-wide JPEG thumbnail is generated; on total failure the original
    remote URL is kept so the record is never left without an image.  Finally
    the GenerationRecord's ``image_urls`` is rewritten to the persisted
    ``{"url": ..., "thumb": ...}`` structure.
    """
    with app.app_context():
        processed_data = []
        for raw_url in raw_urls:
            success = False
            for attempt in range(3):  # up to 3 attempts per image
                try:
                    img_resp = requests.get(get_proxied_url(raw_url), timeout=Config.PROXY_TIMEOUT_SHORT)
                    if img_resp.status_code != 200:
                        # Treat a non-200 as a failed attempt so it goes through
                        # the backoff path (previously it looped back instantly).
                        raise Exception(f"HTTP {img_resp.status_code}")
                    content = img_resp.content
                    base_filename = f"gen-{uuid.uuid4().hex}"
                    full_filename = f"{base_filename}.png"
                    thumb_filename = f"{base_filename}-thumb.jpg"
                    # 1. Upload the original image.
                    s3_client.upload_fileobj(
                        io.BytesIO(content),
                        Config.MINIO["bucket"],
                        full_filename,
                        ExtraArgs={"ContentType": "image/png"}
                    )
                    full_url = f"{Config.MINIO['public_url']}{quote(full_filename)}"
                    thumb_url = full_url  # fall back to the original if thumbnailing fails
                    # 2. Generate and upload a 400px-wide thumbnail.
                    try:
                        img = Image.open(io.BytesIO(content))
                        # JPEG cannot store alpha/palette; convert first.
                        if img.mode in ("RGBA", "P"):
                            img = img.convert("RGB")
                        # Scale to width 400, keeping the aspect ratio.
                        w, h = img.size
                        if w > 400:
                            ratio = 400 / float(w)
                            img.thumbnail((400, int(h * ratio)), Image.Resampling.LANCZOS)
                        thumb_io = io.BytesIO()
                        # JPEG keeps the thumbnail payload small.
                        img.save(thumb_io, format='JPEG', quality=80, optimize=True)
                        thumb_io.seek(0)
                        s3_client.upload_fileobj(
                            thumb_io,
                            Config.MINIO["bucket"],
                            thumb_filename,
                            ExtraArgs={"ContentType": "image/jpeg"}
                        )
                        thumb_url = f"{Config.MINIO['public_url']}{quote(thumb_filename)}"
                    except Exception as thumb_e:
                        # Thumbnail failure is non-fatal; the full image still works.
                        system_logger.warning(f"缩略图生成失败: {thumb_e}")
                    processed_data.append({"url": full_url, "thumb": thumb_url})
                    success = True
                    break
                except Exception as e:
                    # Consistent with the rest of this module: log via system_logger.
                    system_logger.warning(f"第 {attempt+1} 次同步失败: {e}")
                    time.sleep(2 ** attempt)  # exponential backoff: 1s, 2s, 4s
            if not success:
                # Keep the remote URL so the record still renders something.
                processed_data.append({"url": raw_url, "thumb": raw_url})
        # Rewrite the record with the persisted data structure.
        try:
            record = db.session.get(GenerationRecord, record_id)
            if record:
                record.image_urls = json.dumps(processed_data)
                db.session.commit()
                system_logger.info(f"记录 {record_id} 图片及缩略图已完成同步")
        except Exception as e:
            system_logger.error(f"更新记录失败: {e}")
def process_image_generation(app, user_id, task_id, payload, api_key, target_api, cost, use_trial=False):
    """Run an image generation task asynchronously and publish its state to Redis.

    Submits the request with a cautious retry policy (connection-phase errors
    retry; post-connection losses and unknown errors do not, to avoid duplicate
    upstream tasks), polls the remote task until SUCCESS/FAILURE, persists a
    GenerationRecord, kicks off background MinIO mirroring, and writes the
    final state to ``task:{task_id}``.  On any failure, trial points are
    refunded (best effort) and the error state is published so the frontend's
    polling loop can terminate.
    """
    with app.app_context():
        # Mark the task as queued so the frontend can start polling.
        redis_client.setex(f"task:{task_id}", 3600, json.dumps({"status": "processing", "message": "任务已提交,正在排队处理..."}))
        try:
            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
            # 1. Submit the asynchronous request (with retries).
            submit_resp = None
            last_error = None
            for attempt in range(3):
                try:
                    # async=true switches the upstream API into task mode.
                    submit_resp = requests.post(
                        get_proxied_url(target_api),
                        params={"async": "true"},
                        json=payload,
                        headers=headers,
                        timeout=Config.PROXY_TIMEOUT_DEFAULT
                    )
                    if submit_resp.status_code == 200:
                        break  # submitted successfully
                    system_logger.warning(f"任务提交失败(第{attempt+1}次): {submit_resp.status_code} - {submit_resp.text[:100]}", user_id=user_id, task_id=task_id)
                    last_error = f"HTTP {submit_resp.status_code}: {submit_resp.text}"
                    # Brief pause before re-submitting; previously the HTTP-status
                    # failure path retried immediately with no backoff.
                    time.sleep(1)
                except (ConnectTimeout, ConnectionError) as e:
                    # Connection-phase errors (DNS failure, refused, connect
                    # timeout): the request most likely never left, safe to retry.
                    err_str = str(e)
                    if 'RemoteDisconnected' in err_str or 'ReadTimeout' in err_str:
                        # Connection dropped after being established: the request
                        # may have reached the upstream — do NOT retry, it could
                        # duplicate the task.
                        system_logger.warning(f"任务提交后响应丢失(第{attempt+1}次),不再重试: {err_str}", user_id=user_id, task_id=task_id)
                        last_error = err_str
                        break
                    system_logger.warning(f"任务提交连接异常(第{attempt+1}次): {err_str}", user_id=user_id, task_id=task_id)
                    last_error = err_str
                    time.sleep(1)
                except Exception as e:
                    # Unknown failure: be conservative, do not retry.
                    system_logger.warning(f"任务提交异常(第{attempt+1}次): {str(e)}", user_id=user_id, task_id=task_id)
                    last_error = str(e)
                    break
            if not submit_resp or submit_resp.status_code != 200:
                raise Exception(f"任务提交失败(重试3次后): {last_error}")
            submit_result = submit_resp.json()
            # A task_id in the response means the upstream accepted async mode.
            raw_urls = []
            if 'task_id' in submit_result:
                remote_task_id = submit_result['task_id']
                system_logger.info(f"外部异步任务已提交: {remote_task_id}", user_id=user_id, task_id=task_id)
                # Build the poll URL: .../images/generations -> .../images/tasks/{id}
                poll_url = target_api.replace('/generations', f'/tasks/{remote_task_id}')
                if poll_url == target_api:  # replace() was a no-op; fall back to dirname
                    import posixpath
                    base_url = posixpath.dirname(target_api)
                    poll_url = f"{base_url}/tasks/{remote_task_id}"
                system_logger.info(f"开始轮询异步任务: {poll_url}", user_id=user_id, task_id=task_id)
                # 2. Poll until SUCCESS/FAILURE or timeout (~30 minutes).
                max_retries = 600
                generation_success = False
                for i in range(max_retries):
                    # First ~15s poll every second, then every 3 seconds.
                    sleep_time = 1 if i < 15 else 3
                    if i > 0:
                        time.sleep(sleep_time)
                    # Heartbeat so the frontend sees progress.
                    if i % 5 == 0:
                        elapsed = i if i < 15 else (15 + (i - 15) * 3)
                        redis_client.setex(f"task:{task_id}", 3600, json.dumps({
                            "status": "processing",
                            "message": f"正在生成中 (已耗时 {elapsed} 秒)..."
                        }))
                    try:
                        poll_resp = requests.get(get_proxied_url(poll_url), headers=headers, timeout=Config.PROXY_TIMEOUT_SHORT)
                        if poll_resp.status_code != 200:
                            system_logger.warning(f"轮询非 200: {poll_resp.status_code}", user_id=user_id, task_id=task_id)
                            continue
                        poll_data = poll_resp.json()
                        remote_status = poll_data.get('status')
                        if not remote_status and 'data' in poll_data and isinstance(poll_data['data'], dict):
                            remote_status = poll_data['data'].get('status')
                        if remote_status == 'SUCCESS':
                            # Extract result URLs; upstreams nest them differently.
                            data_node = poll_data.get('data')
                            raw_urls = []
                            if isinstance(data_node, dict):
                                inner_node = data_node.get('data')
                                if isinstance(inner_node, dict) and 'data' in inner_node and isinstance(inner_node['data'], list):
                                    # data -> data -> data -> [...] (Comfly structure)
                                    raw_urls = [item.get('url') for item in inner_node['data'] if isinstance(item, dict) and item.get('url')]
                                elif isinstance(inner_node, list):
                                    # data -> data -> [...] (standard)
                                    raw_urls = [item.get('url') for item in inner_node if isinstance(item, dict) and item.get('url')]
                                elif 'url' in data_node:
                                    raw_urls = [data_node['url']]
                            elif isinstance(data_node, list):
                                # data -> [...]
                                raw_urls = [item.get('url') for item in data_node if isinstance(item, dict) and item.get('url')]
                            # Fallback: top-level url field.
                            if not raw_urls and 'url' in poll_data:
                                raw_urls = [poll_data['url']]
                            if raw_urls:
                                generation_success = True
                                break
                        elif remote_status == 'FAILURE':
                            raise Exception(f"生成任务失败: {_extract_error_detail(poll_data)}")
                    except requests.RequestException:
                        continue  # transient network blip — keep polling
                if not generation_success:
                    raise Exception("生成任务超时或未获取到结果")
            else:
                # Legacy synchronous response shape.
                raw_urls = [item['url'] for item in submit_result.get('data', [])]
            if not raw_urls:
                raise Exception("未获取到图片地址")
            # 3. Persist the record (with the remote URLs for now).
            new_record = GenerationRecord(
                user_id=user_id,
                prompt=payload.get('prompt'),
                model=payload.get('model'),
                cost=cost,
                image_urls=json.dumps(raw_urls)
            )
            db.session.add(new_record)
            db.session.commit()
            # 4. Mirror images into MinIO in the background.
            threading.Thread(
                target=sync_images_background,
                args=(app, new_record.id, raw_urls)
            ).start()
            # 5. Publish the final state for the frontend.
            system_logger.info(f"生图任务完成", user_id=user_id, task_id=task_id, model=payload.get('model'))
            redis_client.setex(f"task:{task_id}", 3600, json.dumps({"status": "complete", "urls": raw_urls}))
        except Exception as e:
            # Best-effort refund — guarded so a refund failure cannot prevent
            # the error state below from being written (matches the video flow).
            if use_trial:
                try:
                    from services.generation_service import refund_points
                    refund_points(user_id, cost)
                except Exception as refund_err:
                    system_logger.error(f"退费失败: {str(refund_err)}")
            system_logger.error(f"生图任务异常: {str(e)}", user_id=user_id, task_id=task_id, prompt=payload.get('prompt'), model=payload.get('model'))
            # Always update Redis so the frontend's polling loop terminates.
            redis_client.setex(f"task:{task_id}", 3600, json.dumps({"status": "error", "message": str(e)}))
def sync_video_background(app, record_id, raw_url, internal_task_id=None):
    """Background job: mirror a generated video into MinIO, with retries.

    Streams the remote video (files can be large), uploads it to MinIO, then
    rewrites the GenerationRecord — and, if an internal task id is given, the
    cached Redis task entry — to point at the persisted URL.  On total failure
    the original remote URL is left in place.
    """
    with app.app_context():
        success = False
        final_url = raw_url
        for attempt in range(3):
            try:
                # Stream the download to cope with large video files.
                with requests.get(get_proxied_url(raw_url), stream=True, timeout=Config.PROXY_TIMEOUT_LONG) as r:
                    r.raise_for_status()
                    content_type = r.headers.get('content-type', 'video/mp4')
                    if "text/html" in content_type:
                        # Some APIs return a redirect/landing page instead of the
                        # file; count it as a failed attempt so it goes through
                        # the backoff path (previously `continue` retried with
                        # no delay).
                        raise Exception(f"意外的内容类型: {content_type}")
                    full_filename = f"video-{uuid.uuid4().hex}.mp4"
                    video_io = io.BytesIO()
                    for chunk in r.iter_content(chunk_size=8192):
                        video_io.write(chunk)
                    video_io.seek(0)
                    # Upload to MinIO.
                    s3_client.upload_fileobj(
                        video_io,
                        Config.MINIO["bucket"],
                        full_filename,
                        ExtraArgs={"ContentType": content_type}
                    )
                    final_url = f"{Config.MINIO['public_url']}{quote(full_filename)}"
                success = True
                break
            except Exception as e:
                system_logger.error(f"同步视频失败 (第{attempt+1}次): {str(e)}")
                time.sleep(5)
        if success:
            try:
                record = db.session.get(GenerationRecord, record_id)
                if record:
                    # Point the record at the persisted MinIO URL.
                    record.image_urls = json.dumps([{"url": final_url, "type": "video"}])
                    db.session.commit()
                    # Keep the cached Redis task entry in sync as well.
                    if internal_task_id:
                        cached_data = redis_client.get(f"task:{internal_task_id}")
                        if cached_data:
                            if isinstance(cached_data, bytes):
                                cached_data = cached_data.decode('utf-8')
                            task_info = json.loads(cached_data)
                            task_info['video_url'] = final_url
                            redis_client.setex(f"task:{internal_task_id}", 3600, json.dumps(task_info))
                    system_logger.info(f"视频同步 MinIO 成功", video_url=final_url)
            except Exception as dbe:
                system_logger.error(f"更新视频记录失败: {str(dbe)}")
def process_video_generation(app, user_id, internal_task_id, payload, api_key, cost, use_trial=True):
    """Submit a video generation task and poll the upstream until it finishes.

    On success the video URL is persisted as a GenerationRecord and mirrored
    into MinIO by a background thread; progress and the final state are
    published to Redis under ``task:{internal_task_id}``.  On failure points
    are refunded (best effort) and the error state is written to Redis so the
    frontend's polling loop can terminate.
    """
    with app.app_context():
        try:
            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
            # 1. Submit the task.
            submit_resp = requests.post(get_proxied_url(Config.VIDEO_GEN_API), json=payload, headers=headers, timeout=Config.PROXY_TIMEOUT_DEFAULT)
            if submit_resp.status_code != 200:
                raise Exception(f"视频任务提交失败: {submit_resp.text}")
            submit_result = submit_resp.json()
            remote_task_id = submit_result.get('task_id')
            if not remote_task_id:
                raise Exception(f"未获取到远程任务 ID: {submit_result}")
            # 2. Poll for completion (10s interval, up to ~15 minutes).
            redis_client.setex(f"task:{internal_task_id}", 3600, json.dumps({"status": "processing", "message": "视频生成中,请耐心等待..."}))
            max_retries = 90
            video_url = None
            # Fix: status was previously unbound (NameError after the loop) if
            # every poll iteration hit `continue`.
            status = ''
            for i in range(max_retries):
                # Heartbeat every ~20 seconds so the task isn't considered dead.
                if i % 2 == 0:
                    redis_client.setex(f"task:{internal_task_id}", 3600, json.dumps({
                        "status": "processing",
                        "message": f"视频生成中 (已耗时 {i * 10} 秒)..."
                    }))
                time.sleep(10)
                poll_url = Config.VIDEO_POLL_API.format(task_id=remote_task_id)
                try:
                    poll_resp = requests.get(get_proxied_url(poll_url), headers=headers, timeout=Config.PROXY_TIMEOUT_SHORT)
                except requests.RequestException:
                    # Transient network error: keep polling instead of aborting
                    # the whole task (consistent with the image flow).
                    continue
                if poll_resp.status_code != 200:
                    continue
                poll_result = poll_resp.json()
                # `or ''` guards against an explicit null status in the response,
                # which previously crashed on `.upper()`.
                status = (poll_result.get('status') or '').upper()
                if status == 'SUCCESS':
                    # Extract the output URL; upstreams report it in several shapes.
                    if 'data' in poll_result and isinstance(poll_result['data'], dict):
                        video_url = poll_result['data'].get('output')
                    if not video_url:
                        if 'data' in poll_result and isinstance(poll_result['data'], list) and poll_result['data']:
                            video_url = poll_result['data'][0].get('url')
                        elif 'video' in poll_result:
                            video_url = poll_result['video'].get('url') if isinstance(poll_result['video'], dict) else poll_result['video']
                        elif 'url' in poll_result:
                            video_url = poll_result['url']
                    break
                elif status in ['FAILURE', 'FAILED', 'ERROR']:
                    raise Exception(f"视频生成失败: {_extract_error_detail(poll_result)}")
            if not video_url:
                if status in ['FAILURE', 'FAILED', 'ERROR']:
                    raise Exception(f"视频生成失败: {_extract_error_detail(poll_result)}")
                raise Exception("超时未获取到视频地址")
            # 3. Persist the record.
            new_record = GenerationRecord(
                user_id=user_id,
                prompt=payload.get('prompt'),
                model=payload.get('model'),
                cost=cost,
                image_urls=json.dumps([{"url": video_url, "type": "video"}])
            )
            db.session.add(new_record)
            db.session.commit()
            # Mirror the video into MinIO in the background.
            threading.Thread(
                target=sync_video_background,
                args=(app, new_record.id, video_url, internal_task_id)
            ).start()
            # 4. Publish the final state for the frontend.
            redis_client.setex(f"task:{internal_task_id}", 3600, json.dumps({"status": "complete", "video_url": video_url, "record_id": new_record.id}))
            system_logger.info(f"视频生成任务完成", user_id=user_id, task_id=internal_task_id)
        except Exception as e:
            system_logger.error(f"视频生成执行异常: {str(e)}", user_id=user_id, task_id=internal_task_id, prompt=payload.get('prompt'))
            # Best-effort refund; a refund failure must not block the state update.
            if use_trial:
                try:
                    from services.generation_service import refund_points
                    refund_points(user_id, cost)
                except Exception as refund_err:  # renamed: `re` shadowed the stdlib module name
                    system_logger.error(f"退费失败: {str(refund_err)}")
            # Always update Redis so the frontend's polling loop terminates.
            redis_client.setex(f"task:{internal_task_id}", 3600, json.dumps({"status": "error", "message": str(e)}))