ai_v/blueprints/api.py
24024 a47b84e009 feat(api): 实现图片生成异步任务与任务状态查询接口
- 新增异步图片生成处理函数,支持后台任务执行及积分退还机制
- 实现任务状态查询接口,支持前端实时获取生成进度和结果
- 优化生成逻辑:根据模型类型分流,聊天模型同步调用,图片模型异步执行
- 调整积分预扣除和退还逻辑,保障用户积分安全
- 后台线程同步图片至私有存储,提升响应性能和用户体验
- 新增 /visualizer 路由对应前端控制器页面,辅助3D构图和拍摄角度设置
- 优化前端上传逻辑,新增设置器模式时单图上传限制
- 移除项目中未使用的前端脚本与配置文件,简化代码库维护
2026-01-15 21:42:03 +08:00

402 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import uuid
import json
import requests
import io
import threading
import time
import base64
from flask import Blueprint, request, jsonify, session, current_app
from urllib.parse import quote
from config import Config
from extensions import s3_client, redis_client, db
from models import GenerationRecord, User, SystemDict, SystemNotification
from middlewares.auth import login_required
from services.logger import system_logger
# Blueprint that groups the /api/* endpoints defined in this module; it is
# presumably registered on the Flask app elsewhere — confirm in the app setup.
api_bp = Blueprint('api', __name__)
def _make_jpeg_thumbnail(content, max_width=400):
    """Render encoded image bytes as a JPEG no wider than ``max_width`` pixels.

    Returns a BytesIO positioned at offset 0. Raises whatever Pillow raises
    for undecodable input; the caller treats that as "no thumbnail".
    """
    from PIL import Image
    img = Image.open(io.BytesIO(content))
    # JPEG cannot store modes such as RGBA, P, LA or 16-bit images, so
    # flatten everything except RGB/grayscale to RGB. The previous check
    # only covered RGBA and P, so e.g. LA images failed to save.
    if img.mode not in ("RGB", "L"):
        img = img.convert("RGB")
    width, height = img.size
    if width > max_width:
        # Scale down to max_width keeping the aspect ratio; never request 0 rows.
        target_height = max(1, int(height * max_width / float(width)))
        img.thumbnail((max_width, target_height), Image.Resampling.LANCZOS)
    thumb_io = io.BytesIO()
    # JPEG at quality 80 keeps thumbnails small.
    img.save(thumb_io, format='JPEG', quality=80, optimize=True)
    thumb_io.seek(0)
    return thumb_io


def _sync_one_image(raw_url):
    """Copy one generated image into MinIO and return ``{"url": ..., "thumb": ...}``.

    The original is stored as ``gen-<hex>.png`` with Content-Type image/png
    (the upstream format is assumed, not detected), plus a ``-thumb.jpg``
    thumbnail. Raises on a non-200 download or a failed original upload so
    the caller can retry; a thumbnail failure is only logged and the
    thumbnail URL falls back to the full-size URL.
    """
    img_resp = requests.get(raw_url, timeout=30)
    if img_resp.status_code != 200:
        # Raising makes non-200 responses go through the caller's backoff;
        # previously they were retried immediately.
        raise RuntimeError(f"HTTP {img_resp.status_code}")
    content = img_resp.content

    base_filename = f"gen-{uuid.uuid4().hex}"
    full_filename = f"{base_filename}.png"
    thumb_filename = f"{base_filename}-thumb.jpg"
    bucket = Config.MINIO["bucket"]
    public_url = Config.MINIO['public_url']

    # 1. Upload the original image.
    s3_client.upload_fileobj(
        io.BytesIO(content),
        bucket,
        full_filename,
        ExtraArgs={"ContentType": "image/png"}
    )
    full_url = f"{public_url}{quote(full_filename)}"

    # 2. Best-effort thumbnail; fall back to the original URL on failure.
    thumb_url = full_url
    try:
        s3_client.upload_fileobj(
            _make_jpeg_thumbnail(content),
            bucket,
            thumb_filename,
            ExtraArgs={"ContentType": "image/jpeg"}
        )
        thumb_url = f"{public_url}{quote(thumb_filename)}"
    except Exception as thumb_e:
        print(f"⚠️ 缩略图生成失败: {thumb_e}")
    return {"url": full_url, "thumb": thumb_url}


def sync_images_background(app, record_id, raw_urls, max_attempts=3):
    """Mirror generated images into MinIO with thumbnails, then update the record.

    Intended to run on a background thread; it opens its own app context.
    Each URL is tried up to ``max_attempts`` times, sleeping 1 s, 2 s, 4 s, ...
    between attempts (not after the last one). A URL that never succeeds
    keeps its upstream address for both ``url`` and ``thumb``. Finally the
    GenerationRecord ``record_id`` has its ``image_urls`` rewritten as a JSON
    list of ``{"url", "thumb"}`` objects.

    Args:
        app: The Flask application object (used only for ``app_context()``).
        record_id: Primary key of the GenerationRecord to update.
        raw_urls: Upstream image URLs, in display order.
        max_attempts: Download/upload attempts per URL (default 3).
    """
    with app.app_context():
        processed_data = []
        for raw_url in raw_urls:
            entry = None
            for attempt in range(max_attempts):
                try:
                    entry = _sync_one_image(raw_url)
                    break
                except Exception as e:
                    print(f"⚠️ 第 {attempt+1} 次同步失败: {e}")
                    if attempt + 1 < max_attempts:
                        time.sleep(2 ** attempt)  # exponential backoff
            # Keep the upstream URL when every attempt failed.
            processed_data.append(entry or {"url": raw_url, "thumb": raw_url})

        # Persist the final URL structure on the record.
        try:
            record = GenerationRecord.query.get(record_id)
            if record:
                record.image_urls = json.dumps(processed_data)
                db.session.commit()
                print(f"✅ 记录 {record_id} 图片及缩略图已完成同步")
        except Exception as e:
            # Reset the session after a failed commit instead of leaving it
            # in a failed transaction.
            db.session.rollback()
            print(f"❌ 更新记录失败: {e}")
def _refund_task_points(user_id, api_key, cost):
    """Give ``cost`` points back to ``user_id`` after a failed image task.

    generate() only deducts points when it substitutes one of the platform
    keys (Config.TRIAL_KEY / Config.PREMIUM_KEY), so refunds are limited to
    those keys. The former ``"sk-" in api_key`` test also matched users' own
    API keys and credited points that had never been charged. Refund errors
    are logged and swallowed so the caller can still publish the task status.
    """
    if not cost or api_key not in (Config.TRIAL_KEY, Config.PREMIUM_KEY):
        return
    try:
        user = User.query.get(user_id)
        if user:
            user.points += cost
            db.session.commit()
    except Exception as e:
        db.session.rollback()
        print(f"❌ 积分退还失败: {e}")


def _fail_image_task(task_id, message, user_id, api_key, cost):
    """Refund the task's points when applicable and publish an error status for 1 hour."""
    _refund_task_points(user_id, api_key, cost)
    redis_client.setex(f"task:{task_id}", 3600, json.dumps({"status": "error", "message": message}))


def _request_and_record_images(user_id, payload, api_key, target_api):
    """Call the image API and persist a GenerationRecord holding the returned URLs.

    Returns ``(record_id, raw_urls)``. Raises RuntimeError carrying the
    response body on a non-200 status, or a RuntimeError when the response
    contains no image URL; network, JSON and DB errors propagate unchanged.
    """
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    # Long timeout (1000 s, about 16.7 minutes) so slow generations are not
    # cut off. The old comment claimed 10 minutes, which did not match.
    resp = requests.post(target_api, json=payload, headers=headers, timeout=1000)
    if resp.status_code != 200:
        raise RuntimeError(resp.text)
    api_result = resp.json()
    raw_urls = [item['url'] for item in (api_result.get('data') or []) if item.get('url')]
    if not raw_urls:
        # Do not mark an empty result as "complete" and keep the user's points.
        raise RuntimeError("接口未返回图片")
    new_record = GenerationRecord(
        user_id=user_id,
        prompt=payload.get('prompt'),
        model=payload.get('model'),
        image_urls=json.dumps(raw_urls)
    )
    db.session.add(new_record)
    db.session.commit()
    return new_record.id, raw_urls


def process_image_generation(app, user_id, task_id, payload, api_key, target_api, cost):
    """Run one image-generation task (on a worker thread) and publish its outcome.

    The outcome is stored in Redis under ``task:<task_id>`` for 1 hour:
    ``{"status": "complete", "urls": [...]}`` on success, or
    ``{"status": "error", "message": ...}`` on failure. On failure ``cost``
    points are refunded, but only when a platform key was used (that is when
    generate() charged them). On success a second thread mirrors the images
    into MinIO via sync_images_background.

    Args:
        app: The Flask application object (used only for ``app_context()``).
        user_id: Owner of the task and of the resulting GenerationRecord.
        task_id: Identifier the client polls via /api/task_status/<task_id>.
        payload: JSON body forwarded to the image API.
        api_key: Bearer key used for the upstream call.
        target_api: Upstream image endpoint URL.
        cost: Points to give back if the task fails.
    """
    with app.app_context():
        try:
            record_id, raw_urls = _request_and_record_images(user_id, payload, api_key, target_api)
        except Exception as e:
            # Discard any half-finished transaction first; after a failed
            # commit the refund query would otherwise fail as well.
            db.session.rollback()
            _fail_image_task(task_id, str(e), user_id, api_key, cost)
            return

        # The images exist and are recorded, so the charge stands from here
        # on. Later failures are logged, never refunded.
        try:
            threading.Thread(
                target=sync_images_background,
                args=(app, record_id, raw_urls)
            ).start()
        except Exception as e:
            print(f"⚠️ 图片同步线程启动失败: {e}")
        redis_client.setex(f"task:{task_id}", 3600, json.dumps({"status": "complete", "urls": raw_urls}))
@api_bp.route('/api/task_status/<task_id>')
@login_required
def get_task_status(task_id):
    """Report the state of an asynchronous generation task.

    The worker stores its result as JSON under ``task:<task_id>`` with a
    1-hour expiry. When no entry exists, the task is reported as
    ``{"status": "pending"}``. That covers both a task that has not finished
    yet and one whose entry has already expired.
    """
    cached = redis_client.get(f"task:{task_id}")
    result = json.loads(cached) if cached else {"status": "pending"}
    return jsonify(result)
@api_bp.route('/api/config')
def get_config():
    """Build the front-end option lists from active SystemDict rows.

    Rows are read in descending ``sort_order`` and grouped by ``dict_type``
    into ``models`` (each also carrying its point ``cost``), ``ratios``,
    ``prompts`` and ``sizes``. Rows of any other type are ignored. Any
    failure is reported as ``{"error": ...}`` with status 500.
    """
    # dict_type -> key in the response
    section_by_type = {
        'ai_model': "models",
        'aspect_ratio': "ratios",
        'prompt_tpl': "prompts",
        'ai_image_size': "sizes",
    }
    try:
        rows = (
            SystemDict.query
            .filter_by(is_active=True)
            .order_by(SystemDict.sort_order.desc())
            .all()
        )
        options = {"models": [], "ratios": [], "prompts": [], "sizes": []}
        for row in rows:
            section = section_by_type.get(row.dict_type)
            if section is None:
                continue
            entry = {"label": row.label, "value": row.value}
            if section == "models":
                entry["cost"] = row.cost
            options[section].append(entry)
        return jsonify(options)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@api_bp.route('/api/upload', methods=['POST'])
@login_required
def upload():
    """Store the uploaded ``images`` files in MinIO and return their public URLs.

    - Every file is validated before anything is stored, so one bad file does
      not leave earlier ones orphaned in the bucket.
    - Parts without a filename (an empty file input) are skipped.
    - Only ``image/*`` types other than SVG are accepted. The client-supplied
      type becomes the object's Content-Type, and an HTML or SVG object
      served from the public URL could carry script.
    - Object names are a random hex id plus a sanitized extension taken from
      the client filename.

    Returns ``{"urls": [...]}``, or ``{"error": ...}`` with 400 for a rejected
    file and 500 for storage errors.
    """
    try:
        files = [f for f in request.files.getlist('images') if f.filename]
        for f in files:
            mimetype = (f.mimetype or "").lower()
            if not mimetype.startswith("image/") or mimetype == "image/svg+xml":
                return jsonify({"error": "仅支持上传图片文件"}), 400

        img_urls = []
        for f in files:
            ext = os.path.splitext(f.filename)[1].lower()
            suffix = ext[1:]
            # Keep only short, plain-ASCII alphanumeric extensions from the
            # untrusted client filename.
            if not (suffix.isascii() and suffix.isalnum() and len(suffix) <= 8):
                ext = ""
            filename = f"{uuid.uuid4().hex}{ext}"
            s3_client.upload_fileobj(
                f, Config.MINIO["bucket"], filename,
                ExtraArgs={"ContentType": f.mimetype}
            )
            img_urls.append(f"{Config.MINIO['public_url']}{quote(filename)}")
        return jsonify({"urls": img_urls})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
def _generate_flag(value):
    """Read a boolean request field: JSON booleans as-is, form strings by content.

    With form-encoded requests every value is a string, so ``"false"`` would
    otherwise be truthy and silently switch on premium mode.
    """
    if isinstance(value, str):
        return value.strip().lower() in ("1", "true", "yes", "on")
    return bool(value)


def _generate_image_list(data):
    """Return the ``image_data`` entries as a list for JSON or form requests.

    Form data needs ``getlist`` (``get`` returns only the first value), and a
    single string must not be iterated character by character.
    """
    if hasattr(data, "getlist"):
        return data.getlist('image_data')
    images = data.get('image_data') or []
    return [images] if isinstance(images, str) else list(images)


def _charge_generate_points(user_id, cost):
    """Atomically deduct ``cost`` points from the user; return False if too few.

    A single conditional UPDATE replaces the read-check-write sequence, so two
    concurrent requests cannot both pass the balance check and overspend.
    """
    updated = User.query.filter(User.id == user_id, User.points >= cost)\
        .update({User.points: User.points - cost}, synchronize_session=False)
    db.session.commit()
    return bool(updated)


def _refund_generate_points(user_id, points):
    """Atomically give ``points`` back to the user; a no-op when ``points`` is 0."""
    if not points:
        return
    User.query.filter(User.id == user_id)\
        .update({User.points: User.points + points}, synchronize_session=False)
    db.session.commit()


def _run_chat_generation(user_id, model_value, prompt, api_key, charged):
    """Answer a chat-model request synchronously and build the Flask response.

    The ``charged`` points are refunded when the upstream call fails:
    - a non-200 status is passed through to the client with the upstream
      body as the error;
    - network errors and unexpected response bodies are re-raised after the
      refund.
    Saving the chat history is best-effort, so a DB error does not hide an
    answer the user has already paid for.
    """
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    chat_payload = {
        "model": model_value,
        "messages": [{"role": "user", "content": prompt}]
    }
    try:
        resp = requests.post(Config.CHAT_API, json=chat_payload, headers=headers, timeout=120)
        content = None
        if resp.status_code == 200:
            content = resp.json()['choices'][0]['message']['content']
    except Exception:
        _refund_generate_points(user_id, charged)
        raise
    if resp.status_code != 200:
        _refund_generate_points(user_id, charged)
        return jsonify({"error": resp.text}), resp.status_code

    # Prescription-reading ("解读验光单") requests are deliberately not kept in history.
    if prompt != "解读验光单":
        try:
            db.session.add(GenerationRecord(
                user_id=user_id,
                prompt=prompt,
                model=model_value,
                image_urls=json.dumps([{"type": "text", "content": content}])
            ))
            db.session.commit()
        except Exception as e:
            db.session.rollback()
            print(f"⚠️ 聊天记录保存失败: {e}")
    return jsonify({
        "data": [{"content": content, "type": "text"}],
        "message": "生成成功!"
    })


@api_bp.route('/api/generate', methods=['POST'])
@login_required
def generate():
    """Start a generation for the logged-in user.

    Modes:
    - ``mode == 'key'``: uses the caller's own API key (``apiKey``, which is
      remembered on the user, or the stored key) and charges no points.
    - any other mode: uses the platform trial/premium key and charges the
      model's ``cost`` points (doubled for premium). The points are refunded
      when the generation fails.

    Chat models ("gemini"/"gpt" in the model name) are answered
    synchronously. Image models run on a worker thread, and the response
    carries a ``task_id`` to poll via /api/task_status/<task_id>.
    """
    try:
        user_id = session.get('user_id')
        user = User.query.get(user_id)
        if not user:
            return jsonify({"error": "用户不存在,请重新登录"}), 401
        data = request.json if request.is_json else request.form
        mode = data.get('mode', 'trial')
        is_premium = _generate_flag(data.get('is_premium', False))
        input_key = data.get('apiKey')
        model_value = data.get('model')
        if not model_value or not isinstance(model_value, str):
            return jsonify({"error": "请选择模型"}), 400

        target_api = Config.AI_API
        use_trial = mode != 'key'
        if use_trial:
            if user.points <= 0:
                return jsonify({"error": "可用积分已耗尽,请充值或切换至自定义 Key 模式"}), 400
            api_key = Config.PREMIUM_KEY if is_premium else Config.TRIAL_KEY
            target_api = Config.TRIAL_API
        else:
            api_key = input_key or user.api_key
            if not api_key:
                return jsonify({"error": "请先输入您的 API 密钥"}), 400
            if input_key and input_key != user.api_key:
                user.api_key = input_key
                db.session.commit()

        is_chat_model = "gemini" in model_value.lower() or "gpt" in model_value.lower()
        model_dict = SystemDict.query.filter_by(dict_type='ai_model', value=model_value).first()
        cost = model_dict.cost if model_dict else 1
        if use_trial and is_premium:
            cost *= 2

        # Build the upstream request before charging so malformed input
        # cannot cost the user points.
        prompt = data.get('prompt')
        size = data.get('size')
        payload = {
            "prompt": prompt,
            "model": model_value,
            "response_format": "url",
            "aspect_ratio": data.get('ratio')
        }
        image_data = _generate_image_list(data)
        if image_data:
            # Keep only what follows the first comma (presumably the base64
            # part of a data URL).
            payload["image"] = [img.split(',', 1)[1] if ',' in img else img for img in image_data]
        if model_value == "nano-banana-2" and size:
            payload["image_size"] = size

        # Points actually taken from the user; always 0 in key mode.
        charged = 0
        if use_trial:
            if not _charge_generate_points(user_id, cost):
                return jsonify({"error": "可用积分不足"}), 400
            charged = cost

        if is_chat_model:
            return _run_chat_generation(user_id, model_value, prompt, api_key, charged)

        # --- Asynchronous image generation ---
        task_id = str(uuid.uuid4())
        app = current_app._get_current_object()
        # Pass only the points actually charged: the worker may refund its
        # cost argument on failure, so key-mode tasks carry nothing refundable.
        worker = threading.Thread(
            target=process_image_generation,
            args=(app, user_id, task_id, payload, api_key, target_api, charged)
        )
        try:
            worker.start()
        except Exception:
            _refund_generate_points(user_id, charged)
            raise
        return jsonify({
            "task_id": task_id,
            "message": "已开启异步生成任务"
        })
    except Exception as e:
        db.session.rollback()
        return jsonify({"error": str(e)}), 500
@api_bp.route('/api/notifications/latest', methods=['GET'])
@login_required
def get_latest_notification():
    """Return the newest active notification the current user has not read.

    Responds with ``{"id", "title", "content"}``, or ``{"id": None}`` when
    there is nothing unread; any failure yields ``{"error": ...}`` with 500.
    """
    try:
        current_uid = session.get('user_id')
        not_read_yet = ~SystemNotification.read_by_users.any(id=current_uid)
        newest = (
            SystemNotification.query
            .filter_by(is_active=True)
            .filter(not_read_yet)
            .order_by(SystemNotification.created_at.desc())
            .first()
        )
        if not newest:
            return jsonify({"id": None})
        return jsonify({
            "id": newest.id,
            "title": newest.title,
            "content": newest.content
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@api_bp.route('/api/notifications/read', methods=['POST'])
@login_required
def mark_notif_read():
    """Mark a notification as read by the current user.

    Expects a JSON object body ``{"id": <notification id>}``. Responds 400
    when the id is missing, including when the body is absent, not JSON, or
    not an object. Unknown ids are ignored and still answered with
    ``{"status": "ok"}``. Marking an already-read notification changes nothing.
    """
    try:
        user_id = session.get('user_id')
        # get_json(silent=True) yields None for a missing or malformed body;
        # the former request.json.get(...) crashed with a 500 in that case.
        data = request.get_json(silent=True)
        notif_id = data.get('id') if isinstance(data, dict) else None
        if not notif_id:
            return jsonify({"error": "缺少通知 ID"}), 400
        notif = SystemNotification.query.get(notif_id)
        user = User.query.get(user_id)
        if notif and user and user not in notif.read_by_users:
            notif.read_by_users.append(user)
            db.session.commit()
        return jsonify({"status": "ok"})
    except Exception as e:
        db.session.rollback()
        return jsonify({"error": str(e)}), 500
def _format_history_urls(raw_json):
    """Normalize a record's stored ``image_urls`` JSON into a list of entries.

    Stored formats:
    - bare URL strings: expanded to ``{"url": u, "thumb": u}``;
    - ``{"url", "thumb"}`` dicts (after background syncing): passed through;
    - chat entries such as ``{"type": "text", ...}``: passed through.

    Missing or malformed JSON yields an empty list, so one bad row cannot
    break the whole history page.
    """
    try:
        entries = json.loads(raw_json) if raw_json else []
    except (TypeError, ValueError):
        return []
    if not isinstance(entries, list):
        return []
    return [{"url": u, "thumb": u} if isinstance(u, str) else u for u in entries]


@api_bp.route('/api/history', methods=['GET'])
@login_required
def get_history():
    """Return the current user's generation history, newest first, paginated.

    Only records from the last 90 days are included, and records whose
    prompt is "解读验光单" (the prescription-reading assistant) are excluded.

    Query parameters:
    - ``page``: default 1, minimum 1.
    - ``per_page``: default 10; values below 1 fall back to 10 and values
      are capped at 100 to bound the query size.

    Times are shown at UTC+8. This assumes ``created_at`` is stored as naive
    UTC, which matches the ``utcnow()`` cutoff. NOTE(review): confirm.
    """
    try:
        from datetime import datetime, timedelta
        max_per_page = 100
        user_id = session.get('user_id')
        page = max(request.args.get('page', 1, type=int), 1)
        per_page = request.args.get('per_page', 10, type=int)
        if per_page < 1:
            per_page = 10
        per_page = min(per_page, max_per_page)

        ninety_days_ago = datetime.utcnow() - timedelta(days=90)
        pagination = GenerationRecord.query.filter(
            GenerationRecord.user_id == user_id,
            GenerationRecord.created_at >= ninety_days_ago,
            GenerationRecord.prompt != "解读验光单"  # hide prescription-assistant records
        ).order_by(GenerationRecord.created_at.desc())\
            .paginate(page=page, per_page=per_page, error_out=False)

        history_list = [
            {
                "id": r.id,
                "model": r.model,
                "urls": _format_history_urls(r.image_urls),
                "time": (r.created_at + timedelta(hours=8)).strftime('%Y-%m-%d %H:%M')
            }
            for r in pagination.items
        ]
        return jsonify({
            "history": history_list,
            "has_next": pagination.has_next,
            "total": pagination.total
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500