ai_v/services/task_service.py

283 lines
13 KiB
Python
Raw Normal View History

import os
import uuid
import json
import requests
import io
import time
import base64
import threading
from urllib.parse import quote
from PIL import Image
from extensions import s3_client, redis_client, db
from models import GenerationRecord, User
from config import Config
from services.logger import system_logger
def sync_images_background(app, record_id, raw_urls):
    """Download generated images into MinIO and build 400px JPEG thumbnails.

    Runs on a background thread (spawned by process_image_generation). Each
    provider URL is fetched with up to 3 attempts and exponential backoff; the
    original is stored as PNG and a JPEG thumbnail is uploaded next to it. The
    GenerationRecord identified by ``record_id`` is then rewritten with the
    persisted ``{"url", "thumb"}`` structure. If an image cannot be fetched at
    all, the original provider URL is kept so the record stays usable.

    Args:
        app: Flask application object (supplies the app context for this thread).
        record_id: Primary key of the GenerationRecord to update.
        raw_urls: Image URLs returned by the upstream AI API.
    """
    with app.app_context():
        processed_data = []
        for raw_url in raw_urls:
            success = False
            for attempt in range(3):  # up to 3 attempts per image
                try:
                    img_resp = requests.get(raw_url, timeout=30)
                    # Raise on HTTP errors so the backoff in the except branch
                    # also covers non-200 responses (previously they retried
                    # immediately with no delay).
                    img_resp.raise_for_status()
                    content = img_resp.content
                    base_filename = f"gen-{uuid.uuid4().hex}"
                    full_filename = f"{base_filename}.png"
                    # Thumbnails are stored as JPEG for a smaller footprint.
                    thumb_filename = f"{base_filename}-thumb.jpg"
                    # 1. Upload the original image.
                    s3_client.upload_fileobj(
                        io.BytesIO(content),
                        Config.MINIO["bucket"],
                        full_filename,
                        ExtraArgs={"ContentType": "image/png"}
                    )
                    full_url = f"{Config.MINIO['public_url']}{quote(full_filename)}"
                    thumb_url = full_url  # fall back to the full image if thumbnailing fails
                    # 2. Generate and upload a 400px-wide thumbnail.
                    try:
                        img = Image.open(io.BytesIO(content))
                        # JPEG cannot hold alpha/palette modes; convert to RGB first.
                        if img.mode in ("RGBA", "P"):
                            img = img.convert("RGB")
                        # Scale down to width 400, keeping the aspect ratio.
                        w, h = img.size
                        if w > 400:
                            ratio = 400 / float(w)
                            img.thumbnail((400, int(h * ratio)), Image.Resampling.LANCZOS)
                        thumb_io = io.BytesIO()
                        img.save(thumb_io, format='JPEG', quality=80, optimize=True)
                        thumb_io.seek(0)
                        s3_client.upload_fileobj(
                            thumb_io,
                            Config.MINIO["bucket"],
                            thumb_filename,
                            ExtraArgs={"ContentType": "image/jpeg"}
                        )
                        thumb_url = f"{Config.MINIO['public_url']}{quote(thumb_filename)}"
                    except Exception as thumb_e:
                        # Non-fatal: the record will point the thumb at the full image.
                        system_logger.error(f"⚠️ 缩略图生成失败: {thumb_e}")
                    processed_data.append({"url": full_url, "thumb": thumb_url})
                    success = True
                    break
                except Exception as e:
                    system_logger.error(f"⚠️ 第 {attempt+1} 次同步失败: {e}")
                    if attempt < 2:
                        time.sleep(2 ** attempt)  # exponential backoff; no sleep after the last attempt
            if not success:
                # All attempts failed: keep the original provider URL.
                processed_data.append({"url": raw_url, "thumb": raw_url})
        # Rewrite the DB record with the persisted data structure.
        try:
            record = db.session.get(GenerationRecord, record_id)
            if record:
                record.image_urls = json.dumps(processed_data)
                db.session.commit()
                system_logger.info(f"✅ 记录 {record_id} 图片及缩略图已完成同步")
        except Exception as e:
            system_logger.error(f"❌ 更新记录失败: {e}")
def process_image_generation(app, user_id, task_id, payload, api_key, target_api, cost, use_trial=False):
    """Execute one image-generation request asynchronously and cache the outcome in Redis.

    Posts ``payload`` to ``target_api``, persists a GenerationRecord on success,
    hands the provider URLs to a background MinIO sync thread, and writes the
    final status under ``task:{task_id}`` (1 hour TTL). On any failure, points
    are refunded when ``use_trial`` is set and an error status is cached so
    pollers terminate.
    """
    with app.app_context():
        def _abort(log_text, client_message):
            # Shared failure path: refund charged points, log context, flag Redis.
            if use_trial:
                from services.generation_service import refund_points
                refund_points(user_id, cost)
            system_logger.error(log_text, user_id=user_id, task_id=task_id,
                                prompt=payload.get('prompt'), model=payload.get('model'))
            redis_client.setex(f"task:{task_id}", 3600,
                               json.dumps({"status": "error", "message": client_message}))

        try:
            request_headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
            # Generous timeout so long-running generations are not cut off mid-flight.
            response = requests.post(target_api, json=payload, headers=request_headers, timeout=1000)
            if response.status_code != 200:
                _abort(f"生图任务失败: {response.text}", response.text)
                return
            result = response.json()
            source_urls = [entry['url'] for entry in result.get('data', [])]
            # Persist the generation before kicking off the MinIO mirror.
            record = GenerationRecord(
                user_id=user_id,
                prompt=payload.get('prompt'),
                model=payload.get('model'),
                cost=cost,
                image_urls=json.dumps(source_urls)
            )
            db.session.add(record)
            db.session.commit()
            # Mirror the provider images into private MinIO off-thread.
            threading.Thread(
                target=sync_images_background,
                args=(app, record.id, source_urls)
            ).start()
            # Mark the task complete for pollers.
            system_logger.info(f"生图任务完成", user_id=user_id, task_id=task_id, model=payload.get('model'))
            redis_client.setex(f"task:{task_id}", 3600,
                               json.dumps({"status": "complete", "urls": source_urls}))
        except Exception as e:
            _abort(f"生图任务异常: {str(e)}", str(e))
def sync_video_background(app, record_id, raw_url, internal_task_id=None):
    """Download a generated video into MinIO, with a 3-attempt retry loop.

    Runs on a background thread inside an app context. On success the
    GenerationRecord's image_urls field is rewritten to the MinIO URL and,
    when ``internal_task_id`` is given, the cached Redis task entry is
    refreshed so pollers pick up the persisted URL. If every attempt fails,
    the record is left untouched (still pointing at the provider URL).

    Args:
        app: Flask application object (supplies the app context for this thread).
        record_id: Primary key of the GenerationRecord to update.
        raw_url: Video URL returned by the upstream API.
        internal_task_id: Optional internal task id whose Redis cache entry
            should be updated with the final URL.
    """
    with app.app_context():
        success = False
        final_url = raw_url
        for attempt in range(3):
            try:
                # Stream the download so large video files are not buffered by requests all at once.
                with requests.get(raw_url, stream=True, timeout=120) as r:
                    r.raise_for_status()
                    content_type = r.headers.get('content-type', 'video/mp4')
                    ext = ".mp4"
                    if "text/html" in content_type:  # some APIs return a redirect/landing page instead of the file
                        continue
                    base_filename = f"video-{uuid.uuid4().hex}"
                    full_filename = f"{base_filename}{ext}"
                    # Accumulate the streamed chunks into an in-memory buffer for upload.
                    video_io = io.BytesIO()
                    for chunk in r.iter_content(chunk_size=8192):
                        video_io.write(chunk)
                    video_io.seek(0)
                    # Upload the buffered video to MinIO.
                    s3_client.upload_fileobj(
                        video_io,
                        Config.MINIO["bucket"],
                        full_filename,
                        ExtraArgs={"ContentType": content_type}
                    )
                    final_url = f"{Config.MINIO['public_url']}{quote(full_filename)}"
                    success = True
                    break
            except Exception as e:
                system_logger.error(f"同步视频失败 (第{attempt+1}次): {str(e)}")
                time.sleep(5)
        if success:
            try:
                record = db.session.get(GenerationRecord, record_id)
                if record:
                    # Point the record at the persisted MinIO URL.
                    record.image_urls = json.dumps([{"url": final_url, "type": "video"}])
                    db.session.commit()
                # Keep the cached Redis task entry in sync with the new URL.
                if internal_task_id:
                    cached_data = redis_client.get(f"task:{internal_task_id}")
                    if cached_data:
                        # Redis may return bytes depending on client config; normalize to str.
                        if isinstance(cached_data, bytes):
                            cached_data = cached_data.decode('utf-8')
                        task_info = json.loads(cached_data)
                        task_info['video_url'] = final_url
                        redis_client.setex(f"task:{internal_task_id}", 3600, json.dumps(task_info))
                system_logger.info(f"视频同步 MinIO 成功", video_url=final_url)
            except Exception as dbe:
                system_logger.error(f"更新视频记录失败: {str(dbe)}")
def process_video_generation(app, user_id, internal_task_id, payload, api_key, cost, use_trial=True):
    """Submit a video-generation task, poll it to completion, persist and cache the result.

    Flow: submit to Config.VIDEO_GEN_API, poll Config.VIDEO_POLL_API every 10 s
    (up to 90 polls, ~15 minutes), persist a GenerationRecord, mirror the video
    into MinIO on a background thread, and cache the final status in Redis
    under ``task:{internal_task_id}`` (1 hour TTL). On any failure the points
    are refunded (best-effort, when ``use_trial`` is set) and an error status
    is always written to Redis so the frontend poll loop can terminate.
    """
    with app.app_context():
        try:
            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
            # 1. Submit the task to the upstream API.
            submit_resp = requests.post(Config.VIDEO_GEN_API, json=payload, headers=headers, timeout=60)
            if submit_resp.status_code != 200:
                raise Exception(f"视频任务提交失败: {submit_resp.text}")
            submit_result = submit_resp.json()
            remote_task_id = submit_result.get('task_id')
            if not remote_task_id:
                raise Exception(f"未获取到远程任务 ID: {submit_result}")
            # 2. Poll for completion: 90 polls x 10 s ≈ 15 minutes ceiling.
            max_retries = 90
            video_url = None
            for i in range(max_retries):
                time.sleep(10)
                poll_url = Config.VIDEO_POLL_API.format(task_id=remote_task_id)
                poll_resp = requests.get(poll_url, headers=headers, timeout=30)
                if poll_resp.status_code != 200:
                    # Transient poll failure: skip this round and try again.
                    continue
                poll_result = poll_resp.json()
                status = poll_result.get('status', '').upper()
                if status == 'SUCCESS':
                    # Extract the output video URL; providers differ in response shape,
                    # so try data.output, then data[0].url, then video(.url), then url.
                    if 'data' in poll_result and isinstance(poll_result['data'], dict):
                        video_url = poll_result['data'].get('output')
                    if not video_url:
                        if 'data' in poll_result and isinstance(poll_result['data'], list) and poll_result['data']:
                            video_url = poll_result['data'][0].get('url')
                        elif 'video' in poll_result:
                            video_url = poll_result['video'].get('url') if isinstance(poll_result['video'], dict) else poll_result['video']
                        elif 'url' in poll_result:
                            video_url = poll_result['url']
                    break
                elif status in ['FAILED', 'ERROR']:
                    raise Exception(f"视频生成失败: {poll_result.get('fail_reason') or poll_result.get('message') or '未知错误'}")
            if not video_url:
                # Either the poll loop timed out or SUCCESS arrived without a recognizable URL.
                raise Exception("超时未获取到视频地址")
            # 3. Persist the generation record.
            new_record = GenerationRecord(
                user_id=user_id,
                prompt=payload.get('prompt'),
                model=payload.get('model'),
                cost=cost,
                image_urls=json.dumps([{"url": video_url, "type": "video"}])
            )
            db.session.add(new_record)
            db.session.commit()
            # Mirror the video into MinIO on a background thread.
            threading.Thread(
                target=sync_video_background,
                args=(app, new_record.id, video_url, internal_task_id)
            ).start()
            # 4. Cache completion in Redis for pollers.
            redis_client.setex(f"task:{internal_task_id}", 3600, json.dumps({"status": "complete", "video_url": video_url, "record_id": new_record.id}))
            system_logger.info(f"视频生成任务完成", user_id=user_id, task_id=internal_task_id)
        except Exception as e:
            system_logger.error(f"视频生成执行异常: {str(e)}", user_id=user_id, task_id=internal_task_id, prompt=payload.get('prompt'))
            # Best-effort refund of the charged points.
            if use_trial:
                try:
                    from services.generation_service import refund_points
                    refund_points(user_id, cost)
                except Exception as re:
                    system_logger.error(f"退费失败: {str(re)}")
            # Always update Redis so the frontend poll loop cannot spin forever.
            redis_client.setex(f"task:{internal_task_id}", 3600, json.dumps({"status": "error", "message": str(e)}))