# ai_v/blueprints/api.py
#
# NOTE: The repository viewer flagged this file as containing ambiguous
# Unicode characters (characters that can be confused with other characters).
# If that is intentional the warning can be ignored; otherwise review the
# string literals and comments below.

import os
import uuid
import json
import requests
import io
import threading
import time
import base64
from flask import Blueprint, request, jsonify, session, current_app
from urllib.parse import quote
from config import Config
from extensions import s3_client, redis_client, db
from models import GenerationRecord, User, SystemDict, SystemNotification
from middlewares.auth import login_required
from services.logger import system_logger
# Blueprint on which the route handlers in this module are registered.
api_bp = Blueprint('api', __name__)
def sync_images_background(app, record_id, raw_urls):
    """Mirror generated images into MinIO and persist the mirrored URLs.

    Every URL in ``raw_urls`` is downloaded and uploaded to the configured
    MinIO bucket together with a JPEG thumbnail. Each image gets up to three
    attempts with exponential backoff; when all of them fail the remote URL
    is kept for both the image and its thumbnail. Afterwards the
    ``image_urls`` column of the ``GenerationRecord`` identified by
    ``record_id`` is overwritten with a JSON list of
    ``{"url": ..., "thumb": ...}`` objects.

    Args:
        app: Flask application whose app context is pushed for the whole run.
        record_id: Primary key of the GenerationRecord to update.
        raw_urls: Iterable of remote image URLs.
    """
    max_attempts = 3
    with app.app_context():
        processed_data = []
        for raw_url in raw_urls:
            mirrored = None
            for attempt in range(max_attempts):
                try:
                    mirrored = _mirror_image_to_minio(raw_url)
                    break
                except Exception as e:
                    system_logger.error(f"⚠️ 第 {attempt+1} 次同步失败: {e}")
                    # Exponential backoff between attempts; sleeping after the
                    # final attempt would only delay the fallback below.
                    if attempt < max_attempts - 1:
                        time.sleep(2 ** attempt)
            # Fall back to the remote URL when every attempt failed.
            processed_data.append(mirrored or {"url": raw_url, "thumb": raw_url})

        # Replace the stored raw URLs with the mirrored url/thumb structure.
        try:
            record = db.session.get(GenerationRecord, record_id)
            if record:
                record.image_urls = json.dumps(processed_data)
                db.session.commit()
                system_logger.info(f"✅ 记录 {record_id} 图片及缩略图已完成同步")
        except Exception as e:
            # Leave the session usable for whoever touches it next.
            db.session.rollback()
            system_logger.error(f"❌ 更新记录失败: {e}")


def _mirror_image_to_minio(raw_url):
    """Download one image, upload it plus a thumbnail, return the URL pair.

    Raises on download or original-upload failure so the caller can retry.
    A thumbnail failure is logged and the full-size URL is used instead.
    """
    img_resp = requests.get(raw_url, timeout=30)
    # A non-200 answer is a failure: raising lets the caller log and back off
    # instead of immediately hammering the remote host again.
    if img_resp.status_code != 200:
        raise RuntimeError(f"HTTP {img_resp.status_code}")
    content = img_resp.content

    base_filename = f"gen-{uuid.uuid4().hex}"
    full_filename = f"{base_filename}.png"
    s3_client.upload_fileobj(
        io.BytesIO(content),
        Config.MINIO["bucket"],
        full_filename,
        ExtraArgs={"ContentType": "image/png"},
    )
    full_url = f"{Config.MINIO['public_url']}{quote(full_filename)}"

    thumb_url = full_url  # default when no thumbnail could be produced
    try:
        thumb_filename = f"{base_filename}-thumb.jpg"
        s3_client.upload_fileobj(
            _build_jpeg_thumbnail(content),
            Config.MINIO["bucket"],
            thumb_filename,
            ExtraArgs={"ContentType": "image/jpeg"},
        )
        thumb_url = f"{Config.MINIO['public_url']}{quote(thumb_filename)}"
    except Exception as thumb_e:
        system_logger.error(f"⚠️ 缩略图生成失败: {thumb_e}")
    return {"url": full_url, "thumb": thumb_url}


def _build_jpeg_thumbnail(content, max_width=400):
    """Return a BytesIO holding a JPEG thumbnail at most ``max_width`` wide."""
    from PIL import Image

    img = Image.open(io.BytesIO(content))
    # JPEG cannot store alpha or palette data; normalise every mode other
    # than plain RGB / greyscale (RGBA, P, LA, PA, ...) before saving.
    if img.mode not in ("RGB", "L"):
        img = img.convert("RGB")
    width, height = img.size
    if width > max_width:
        ratio = max_width / float(width)
        # Height is clamped to 1 so extremely wide images never request 0px.
        img.thumbnail((max_width, max(1, int(height * ratio))), Image.Resampling.LANCZOS)
    thumb_io = io.BytesIO()
    img.save(thumb_io, format='JPEG', quality=80, optimize=True)
    thumb_io.seek(0)
    return thumb_io
def process_image_generation(app, user_id, task_id, payload, api_key, target_api, cost):
    """Run an image generation request and publish its outcome in Redis.

    Posts ``payload`` to ``target_api``. On success a ``GenerationRecord``
    is stored, a background thread mirrors the images to MinIO, and
    ``task:<task_id>`` is set to ``{"status": "complete", "urls": [...]}``.
    On failure the points are refunded (at most once, and only while no
    record has been saved) and ``task:<task_id>`` is set to
    ``{"status": "error", "message": ...}``. Redis entries expire after
    3600 seconds.

    Args:
        app: Flask application whose app context is pushed for the run.
        user_id: Primary key of the requesting user.
        task_id: Identifier the client polls with.
        payload: JSON body forwarded to the generation API.
        api_key: Bearer token for the generation API.
        target_api: URL of the generation endpoint.
        cost: Points to give back when the generation fails.
    """
    task_key = f"task:{task_id}"
    with app.app_context():
        refunded = False
        record_saved = False
        try:
            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
            # Generous timeout (1000 s, roughly 16 minutes) so slow
            # generations are not cut off.
            resp = requests.post(target_api, json=payload, headers=headers, timeout=1000)
            if resp.status_code != 200:
                _refund_image_points(user_id, api_key, cost, task_id)
                # Remember the refund: if logging or Redis fails below, the
                # except branch must not hand the points back a second time.
                refunded = True
                system_logger.error(f"生图任务失败: {resp.text}", user_id=user_id, task_id=task_id)
                redis_client.setex(task_key, 3600, json.dumps({"status": "error", "message": resp.text}))
                return

            api_result = resp.json()
            raw_urls = [item['url'] for item in (api_result.get('data') or [])]
            if not raw_urls:
                # An empty result is a failed generation: refund it instead
                # of reporting "complete" with nothing to show.
                raise ValueError("接口未返回任何图片地址")

            # Persist the record with the raw remote URLs first.
            new_record = GenerationRecord(
                user_id=user_id,
                prompt=payload.get('prompt'),
                model=payload.get('model'),
                image_urls=json.dumps(raw_urls)
            )
            db.session.add(new_record)
            db.session.commit()
            record_saved = True

            # Mirror the remote images into MinIO in the background.
            threading.Thread(
                target=sync_images_background,
                args=(app, new_record.id, raw_urls)
            ).start()

            system_logger.info("生图任务完成", user_id=user_id, task_id=task_id, model=payload.get('model'))
            redis_client.setex(task_key, 3600, json.dumps({"status": "complete", "urls": raw_urls}))
        except Exception as e:
            # Once the record is saved the user already has the images, so
            # refund only failures that happened before that point.
            if not refunded and not record_saved:
                _refund_image_points(user_id, api_key, cost, task_id)
            system_logger.error(f"生图任务异常: {str(e)}", user_id=user_id, task_id=task_id)
            redis_client.setex(task_key, 3600, json.dumps({"status": "error", "message": str(e)}))


def _refund_image_points(user_id, api_key, cost, task_id):
    """Best-effort refund of ``cost`` points; never raises.

    Swallowing refund errors guarantees the caller still reaches the Redis
    status update, so the polling client is not left waiting forever.
    """
    try:
        # Discard any half-finished transaction before touching the user row.
        db.session.rollback()
        user = db.session.get(User, user_id)
        # The "sk-" gate is kept from the original implementation.
        # NOTE(review): presumably meant to tell platform keys from user keys - confirm.
        if user and "sk-" in api_key:
            user.points += cost
            db.session.commit()
    except Exception as refund_err:
        system_logger.error(f"退费失败: {str(refund_err)}", user_id=user_id, task_id=task_id)
def sync_video_background(app, record_id, raw_url, internal_task_id=None):
    """Mirror a generated video into MinIO and repoint record and cache at it.

    The video at ``raw_url`` is downloaded and uploaded to the configured
    MinIO bucket, with up to three attempts five seconds apart. On success
    the ``GenerationRecord`` ``record_id`` is rewritten to reference the
    MinIO URL and, when ``internal_task_id`` is given and its Redis entry
    still exists, the cached ``video_url`` is replaced as well. When every
    attempt fails nothing is changed.

    Args:
        app: Flask application whose app context is pushed for the run.
        record_id: Primary key of the GenerationRecord to update.
        raw_url: Remote URL of the generated video.
        internal_task_id: Optional task id whose ``task:<id>`` Redis entry
            should be refreshed.
    """
    max_attempts = 3
    with app.app_context():
        final_url = None
        for attempt in range(max_attempts):
            try:
                final_url = _mirror_video_to_minio(raw_url)
                break
            except Exception as e:
                system_logger.error(f"同步视频失败 (第{attempt+1}次): {str(e)}")
                # Pause between attempts only; nothing follows the last one.
                if attempt < max_attempts - 1:
                    time.sleep(5)
        if final_url is None:
            return

        try:
            record = db.session.get(GenerationRecord, record_id)
            if not record:
                return
            # Point the record at the MinIO copy.
            record.image_urls = json.dumps([{"url": final_url, "type": "video"}])
            db.session.commit()
        except Exception as dbe:
            db.session.rollback()
            system_logger.error(f"更新视频记录失败: {str(dbe)}")
            return

        # Refresh the cached task so polling clients also see the MinIO URL.
        # Kept separate from the DB update so a Redis problem is not
        # reported as a database failure.
        if internal_task_id:
            try:
                task_key = f"task:{internal_task_id}"
                cached_data = redis_client.get(task_key)
                if cached_data:
                    if isinstance(cached_data, bytes):
                        cached_data = cached_data.decode('utf-8')
                    task_info = json.loads(cached_data)
                    task_info['video_url'] = final_url
                    redis_client.setex(task_key, 3600, json.dumps(task_info))
            except Exception as cache_err:
                system_logger.error(f"更新视频任务缓存失败: {str(cache_err)}")
        system_logger.info("视频同步 MinIO 成功", video_url=final_url)


def _mirror_video_to_minio(raw_url):
    """Download ``raw_url`` and upload it to MinIO; return the public URL.

    Raises on HTTP errors, on an HTML answer, and on upload failure so the
    caller can log the attempt and retry.
    """
    with requests.get(raw_url, stream=True, timeout=120) as r:
        r.raise_for_status()
        content_type = r.headers.get('content-type', 'video/mp4')
        # Some APIs answer with an HTML redirect page instead of the video;
        # treat that as a failed attempt so it is logged and retried later.
        if "text/html" in content_type:
            raise ValueError(f"unexpected content type: {content_type}")

        full_filename = f"video-{uuid.uuid4().hex}.mp4"
        # Read the body in chunks into an in-memory buffer.
        video_io = io.BytesIO()
        for chunk in r.iter_content(chunk_size=8192):
            video_io.write(chunk)
        video_io.seek(0)

        s3_client.upload_fileobj(
            video_io,
            Config.MINIO["bucket"],
            full_filename,
            ExtraArgs={"ContentType": content_type}
        )
    return f"{Config.MINIO['public_url']}{quote(full_filename)}"
def process_video_generation(app, user_id, internal_task_id, payload, api_key, cost):
    """Submit a video generation job, poll it, and publish the outcome.

    Steps: submit ``payload`` to ``Config.VIDEO_GEN_API``; poll
    ``Config.VIDEO_POLL_API`` every 10 seconds, up to 90 times; on success
    store a ``GenerationRecord``, write
    ``{"status": "complete", "video_url": ..., "record_id": ...}`` to
    ``task:<internal_task_id>`` and start a background MinIO sync. On
    failure ``cost`` points are refunded (only while no record has been
    saved) and ``{"status": "error", "message": ...}`` is written. Redis
    entries expire after 3600 seconds.

    Args:
        app: Flask application whose app context is pushed for the run.
        user_id: Primary key of the requesting user.
        internal_task_id: Identifier the client polls with.
        payload: JSON body forwarded to the video API.
        api_key: Bearer token for the video API.
        cost: Points to give back when the job fails.
    """
    task_key = f"task:{internal_task_id}"
    with app.app_context():
        record_saved = False
        try:
            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

            # 1. Submit the job.
            submit_resp = requests.post(Config.VIDEO_GEN_API, json=payload, headers=headers, timeout=60)
            if submit_resp.status_code != 200:
                raise Exception(f"视频任务提交失败: {submit_resp.text}")
            submit_result = submit_resp.json()
            remote_task_id = submit_result.get('task_id')
            if not remote_task_id:
                raise Exception(f"未获取到远程任务 ID: {submit_result}")

            # 2. Poll for completion (90 polls x 10 s = about 15 minutes).
            poll_url = Config.VIDEO_POLL_API.format(task_id=remote_task_id)
            max_retries = 90
            video_url = None
            succeeded = False
            for _ in range(max_retries):
                time.sleep(10)
                try:
                    poll_resp = requests.get(poll_url, headers=headers, timeout=30)
                    if poll_resp.status_code != 200:
                        continue
                    poll_result = poll_resp.json()
                except (requests.RequestException, ValueError) as poll_err:
                    # A single network hiccup or unparsable answer must not
                    # abort a job that may still finish remotely.
                    system_logger.error(
                        f"视频任务轮询异常: {str(poll_err)}",
                        user_id=user_id, task_id=internal_task_id
                    )
                    continue
                if not isinstance(poll_result, dict):
                    continue

                # "status" may be missing or null; normalise before comparing.
                status = str(poll_result.get('status') or '').upper()
                if status == 'SUCCESS':
                    succeeded = True
                    video_url = _extract_video_url(poll_result)
                    break
                if status in ('FAILED', 'ERROR'):
                    raise Exception(f"视频生成失败: {poll_result.get('fail_reason') or poll_result.get('message') or '未知错误'}")

            if not video_url:
                if succeeded:
                    raise Exception("视频生成成功但未返回视频地址")
                raise Exception("超时未获取到视频地址")

            # 3. Persist the record with the remote URL.
            new_record = GenerationRecord(
                user_id=user_id,
                prompt=payload.get('prompt'),
                model=payload.get('model'),
                image_urls=json.dumps([{"url": video_url, "type": "video"}])
            )
            db.session.add(new_record)
            db.session.commit()
            record_saved = True

            # 4. Publish completion BEFORE starting the sync thread: that
            # thread rewrites this Redis entry, so writing afterwards could
            # overwrite the MinIO URL with the stale remote one.
            redis_client.setex(task_key, 3600, json.dumps({"status": "complete", "video_url": video_url, "record_id": new_record.id}))

            threading.Thread(
                target=sync_video_background,
                args=(app, new_record.id, video_url, internal_task_id)
            ).start()
            system_logger.info("视频生成任务完成", user_id=user_id, task_id=internal_task_id)
        except Exception as e:
            system_logger.error(f"视频生成执行异常: {str(e)}", user_id=user_id, task_id=internal_task_id)
            # Refund only when the user did not end up with a saved video.
            if not record_saved:
                try:
                    # Drop any half-finished transaction before the refund.
                    db.session.rollback()
                    user = db.session.get(User, user_id)
                    if user:
                        user.points += cost
                        db.session.commit()
                except Exception as refund_err:
                    system_logger.error(f"退费失败: {str(refund_err)}")
            # Always publish a terminal state so the client stops polling.
            redis_client.setex(task_key, 3600, json.dumps({"status": "error", "message": str(e)}))


def _extract_video_url(poll_result):
    """Return the video URL from a successful poll response, or None.

    Checked in order: ``data.output`` (dict), ``data[0].url`` (non-empty
    list), ``video.url`` or ``video`` itself, then top-level ``url``.
    """
    data = poll_result.get('data')
    if isinstance(data, dict) and data.get('output'):
        return data['output']
    if isinstance(data, list) and data:
        first = data[0]
        return first.get('url') if isinstance(first, dict) else None
    if 'video' in poll_result:
        video = poll_result['video']
        return video.get('url') if isinstance(video, dict) else video
    if 'url' in poll_result:
        return poll_result['url']
    return None
@api_bp.route('/api/task_status/<task_id>')
@login_required
def get_task_status(task_id):
    """Return the cached state of an asynchronous task.

    Reads ``task:<task_id>`` from Redis. A missing entry is reported as
    ``{"status": "pending"}``; any failure while reading or decoding is
    logged and reported as ``{"status": "error", ...}``.
    """
    try:
        cached = redis_client.get(f"task:{task_id}")
        if cached:
            # The Redis client may hand back either bytes or str.
            text = cached.decode('utf-8') if isinstance(cached, bytes) else cached
            return jsonify(json.loads(text))
        return jsonify({"status": "pending"})
    except Exception as e:
        system_logger.error(f"查询任务状态异常: {str(e)}")
        return jsonify({"status": "error", "message": "状态查询失败"})
@api_bp.route('/api/config')
def get_config():
    """Return front-end configuration built from active SystemDict rows.

    Active rows are read in descending ``sort_order`` and grouped by
    ``dict_type`` into models, ratios, prompts, sizes, video_models and
    video_prompts. Each entry carries ``label`` and ``value``; model and
    video-model entries also carry ``cost``. Rows with any other
    ``dict_type`` are ignored. Errors yield a 500 with the error text.
    """
    # dict_type -> (response section, whether the entry exposes its cost)
    routing = {
        'ai_model': ("models", True),
        'aspect_ratio': ("ratios", False),
        'prompt_tpl': ("prompts", False),
        'ai_image_size': ("sizes", False),
        'video_model': ("video_models", True),
        'video_prompt': ("video_prompts", False),
    }
    try:
        rows = (
            SystemDict.query
            .filter_by(is_active=True)
            .order_by(SystemDict.sort_order.desc())
            .all()
        )
        config = {
            section: []
            for section in ("models", "ratios", "prompts", "sizes", "video_models", "video_prompts")
        }
        for row in rows:
            entry = {"label": row.label, "value": row.value}
            target = routing.get(row.dict_type)
            if target is None:
                continue
            section, with_cost = target
            if with_cost:
                entry["cost"] = row.cost
            config[section].append(entry)
        return jsonify(config)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@api_bp.route('/api/upload', methods=['POST'])
@login_required
def upload():
    """Upload the files posted under ``images`` to MinIO.

    Each file is stored under a random hex name that keeps the original
    extension. Returns ``{"urls": [...]}`` with one public URL per stored
    file, or a 500 with the error text when an upload fails.
    """
    try:
        files = request.files.getlist('images')
        img_urls = []
        for f in files:
            # Browsers send an empty part without a filename when no file
            # was chosen; storing it would create a useless empty object.
            if not f.filename:
                continue
            ext = os.path.splitext(f.filename)[1]
            filename = f"{uuid.uuid4().hex}{ext}"
            s3_client.upload_fileobj(
                f, Config.MINIO["bucket"], filename,
                # A part may arrive without a content type; use a generic
                # one rather than passing None as ContentType.
                ExtraArgs={"ContentType": f.content_type or "application/octet-stream"}
            )
            img_urls.append(f"{Config.MINIO['public_url']}{quote(filename)}")
        system_logger.info(f"用户上传文件: {len(img_urls)}", user_id=session.get('user_id'))
        return jsonify({"urls": img_urls})
    except Exception as e:
        system_logger.error(f"文件上传失败: {str(e)}", user_id=session.get('user_id'))
        return jsonify({"error": str(e)}), 500
@api_bp.route('/api/generate', methods=['POST'])
@login_required
def generate():
    """Start an image generation task or answer a chat-model prompt.

    Request fields (JSON or form): ``mode`` ('key' or points mode, default
    'trial'), ``is_premium``, ``apiKey``, ``model``, ``prompt``, ``ratio``,
    ``size``, ``image_data``.

    * Key mode uses the supplied or stored API key and charges no points.
    * Points mode uses a platform key and deducts the model's cost
      (doubled for premium) before any work starts.
    * Models whose name contains "gemini" or "gpt" are answered
      synchronously through ``Config.CHAT_API``.
    * Every other model is handed to ``process_image_generation`` on a
      background thread; the response then carries a ``task_id`` to poll.

    Points deducted here are refunded if the request fails before the
    background worker takes over.
    """
    user_id = None
    charged = 0  # points deducted in this request and not yet refunded / handed off
    try:
        user_id = session.get('user_id')
        user = db.session.get(User, user_id)
        if user is None:
            return jsonify({"error": "用户不存在,请重新登录"}), 401
        data = request.json if request.is_json else request.form

        # Validate before any side effect: a missing model used to surface
        # as an opaque 500 from ``None.lower()``.
        model_value = data.get('model')
        if not isinstance(model_value, str) or not model_value:
            return jsonify({"error": "缺少模型参数"}), 400

        mode = data.get('mode', 'trial')
        is_premium = data.get('is_premium', False)
        if isinstance(is_premium, str):
            # Form posts deliver strings; a plain truthiness test would
            # treat "false" as premium.
            is_premium = is_premium.strip().lower() in ("1", "true", "yes", "on")
        input_key = data.get('apiKey')

        target_api = Config.AI_API
        use_trial = False
        if mode == 'key':
            api_key = input_key or user.api_key
            if not api_key:
                return jsonify({"error": "请先输入您的 API 密钥"}), 400
            # Remember a newly supplied key for later requests.
            if input_key and input_key != user.api_key:
                user.api_key = input_key
                db.session.commit()
        elif user.points > 0:
            api_key = Config.PREMIUM_KEY if is_premium else Config.TRIAL_KEY
            target_api = Config.TRIAL_API
            use_trial = True
        else:
            return jsonify({"error": "可用积分已耗尽,请充值或切换至自定义 Key 模式"}), 400

        lowered_model = model_value.lower()
        is_chat_model = "gemini" in lowered_model or "gpt" in lowered_model
        model_dict = SystemDict.query.filter_by(dict_type='ai_model', value=model_value).first()
        cost = model_dict.cost if model_dict else 1
        if use_trial and is_premium:
            cost *= 2
        if use_trial:
            if user.points < cost:
                return jsonify({"error": "可用积分不足"}), 400
            user.points -= cost
            user.has_used_points = True  # marks that the user has spent points
            db.session.commit()
            charged = cost

        prompt = data.get('prompt')
        ratio = data.get('ratio')
        size = data.get('size')
        if request.is_json:
            image_data = data.get('image_data') or []
        else:
            image_data = data.getlist('image_data')
        if isinstance(image_data, str):
            # A lone string would otherwise be iterated character by character.
            image_data = [image_data]

        payload = {
            "prompt": prompt,
            "model": model_value,
            "response_format": "url",
            "aspect_ratio": ratio
        }
        if image_data:
            # Strip a "data:...;base64," style prefix when present.
            payload["image"] = [img.split(',', 1)[1] if ',' in img else img for img in image_data]
        if model_value == "nano-banana-2" and size:
            payload["image_size"] = size

        # Chat models are answered synchronously.
        if is_chat_model:
            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
            chat_payload = {
                "model": model_value,
                "messages": [{"role": "user", "content": prompt}]
            }
            resp = requests.post(Config.CHAT_API, json=chat_payload, headers=headers, timeout=120)
            if resp.status_code != 200:
                if charged:
                    _refund_generation_points(user_id, charged)
                    charged = 0
                return jsonify({"error": resp.text}), resp.status_code
            api_result = resp.json()
            content = api_result['choices'][0]['message']['content']
            # Store chat history, except for this assistant prompt.
            if prompt != "解读验光单":
                new_record = GenerationRecord(
                    user_id=user_id,
                    prompt=prompt,
                    model=model_value,
                    image_urls=json.dumps([{"type": "text", "content": content}])
                )
                db.session.add(new_record)
                db.session.commit()
            return jsonify({
                "data": [{"content": content, "type": "text"}],
                "message": "生成成功!"
            })

        # --- asynchronous image generation ---
        task_id = str(uuid.uuid4())
        app = current_app._get_current_object()
        log_msg = "用户发起验光单解读" if prompt == "解读验光单" else "用户发起生图任务"
        system_logger.info(log_msg, model=model_value, mode=mode)
        threading.Thread(
            target=process_image_generation,
            # Pass only what was really charged: in key mode that is 0, so
            # the worker's failure refund cannot credit points that were
            # never deducted.
            args=(app, user_id, task_id, payload, api_key, target_api, charged)
        ).start()
        charged = 0  # the worker owns the refund from here on
        return jsonify({
            "task_id": task_id,
            "message": "已开启异步生成任务"
        })
    except Exception as e:
        # Anything that failed after the deduction (network error, bad chat
        # response, DB error, ...) must not cost the user points.
        if charged:
            _refund_generation_points(user_id, charged)
        return jsonify({"error": str(e)}), 500


def _refund_generation_points(user_id, points):
    """Best-effort refund of ``points`` to ``user_id``; never raises."""
    try:
        # Drop any failed or half-finished transaction before the refund.
        db.session.rollback()
        user = db.session.get(User, user_id)
        if user:
            user.points += points
            db.session.commit()
    except Exception as refund_err:
        system_logger.error(f"退费失败: {str(refund_err)}", user_id=user_id)
@api_bp.route('/api/video/generate', methods=['POST'])
@login_required
def video_generate():
    """Charge points and start an asynchronous video generation task.

    Expects a JSON object with ``model`` (default 'veo3.1'), ``prompt``,
    ``enhance_prompt``, ``images`` and ``aspect_ratio``. The cost comes
    from the matching ``video_model`` SystemDict row, falling back to 15
    for names containing "pro" or "3.1" and 10 otherwise. Points are
    deducted up front and the job runs in ``process_video_generation`` on
    a background thread; the response carries the ``task_id`` to poll.
    Points are refunded here if the hand-off itself fails.
    """
    user_id = None
    charged = 0  # points deducted in this request and not yet handed off
    try:
        user_id = session.get('user_id')
        user = db.session.get(User, user_id)
        if user is None:
            return jsonify({"error": "用户不存在,请重新登录"}), 401

        # A missing or non-JSON body is a client error, not a server error.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "请求体必须为 JSON 对象"}), 400

        # Video generation always runs in points mode (no custom key).
        if user.points <= 0:
            return jsonify({"error": "可用积分不足,请先充值"}), 400

        # ``or`` also covers an explicit null / empty model value.
        model_value = data.get('model') or 'veo3.1'

        # Cost: dictionary entry first, name-based fallback otherwise.
        model_dict = SystemDict.query.filter_by(dict_type='video_model', value=model_value).first()
        if model_dict:
            cost = model_dict.cost
        else:
            cost = 15 if "pro" in model_value.lower() or "3.1" in model_value else 10
        if user.points < cost:
            return jsonify({"error": f"积分不足,生成该视频需要 {cost} 积分"}), 400

        # Deduct the points.
        user.points -= cost
        user.has_used_points = True
        db.session.commit()
        charged = cost

        payload = {
            "model": model_value,
            "prompt": data.get('prompt'),
            "enhance_prompt": data.get('enhance_prompt', False),
            "images": data.get('images', []),
            "aspect_ratio": data.get('aspect_ratio', '9:16')
        }

        # Always use the platform's built-in key.
        api_key = Config.TRIAL_KEY
        task_id = str(uuid.uuid4())
        app = current_app._get_current_object()
        system_logger.info("用户发起视频生成任务 (积分模式)", model=model_value, cost=cost)
        threading.Thread(
            target=process_video_generation,
            args=(app, user_id, task_id, payload, api_key, cost)
        ).start()
        charged = 0  # the worker owns the refund from here on
        return jsonify({
            "task_id": task_id,
            "message": "视频生成任务已提交,系统正在导演中..."
        })
    except Exception as e:
        # The worker never started, so nobody else will refund these points.
        if charged:
            try:
                db.session.rollback()
                user = db.session.get(User, user_id)
                if user:
                    user.points += charged
                    db.session.commit()
            except Exception as refund_err:
                system_logger.error(f"退费失败: {str(refund_err)}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/api/notifications/latest', methods=['GET'])
@login_required
def get_latest_notification():
    """Return the newest active notification the current user has not read.

    Responds with ``id``, ``title`` and ``content`` of that notification,
    or ``{"id": None}`` when there is none. Errors yield a 500 with the
    error text.
    """
    try:
        user_id = session.get('user_id')
        already_read = SystemNotification.read_by_users.any(id=user_id)
        unread_query = (
            SystemNotification.query
            .filter_by(is_active=True)
            .filter(~already_read)
            .order_by(SystemNotification.created_at.desc())
        )
        newest = unread_query.first()
        if not newest:
            return jsonify({"id": None})
        return jsonify({
            "id": newest.id,
            "title": newest.title,
            "content": newest.content
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@api_bp.route('/api/notifications/read', methods=['POST'])
@login_required
def mark_notif_read():
    """Record that the current user has read a notification.

    Expects a JSON body with ``id``. Responds 400 when the id is missing,
    ``{"status": "ok"}`` otherwise (also when the notification or user is
    not found), and 500 with the error text on failure.
    """
    try:
        user_id = session.get('user_id')
        notif_id = request.json.get('id')
        if not notif_id:
            return jsonify({"error": "缺少通知 ID"}), 400

        notif = db.session.get(SystemNotification, notif_id)
        user = db.session.get(User, user_id)
        # Only add the user once; repeated calls are no-ops.
        needs_marking = bool(notif and user) and user not in notif.read_by_users
        if needs_marking:
            notif.read_by_users.append(user)
            db.session.commit()
        return jsonify({"status": "ok"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@api_bp.route('/api/history', methods=['GET'])
@login_required
def get_history():
    """Return the current user's generation history, paginated.

    Query parameters: ``page`` (default 1) and ``per_page`` (default 10,
    capped at 100). Only records from the last 90 days are returned,
    newest first, and records whose prompt is "解读验光单" are excluded.
    Each record's stored URLs are normalised to a list of dicts: plain
    strings (legacy format) become ``{"url": u, "thumb": u}`` and video
    entries without a thumbnail get a placeholder icon. ``created_at`` is
    shifted by +8 hours before formatting.
    """
    try:
        from datetime import datetime, timedelta

        user_id = session.get('user_id')
        page = request.args.get('page', 1, type=int)
        per_page = request.args.get('per_page', 10, type=int)

        ninety_days_ago = datetime.utcnow() - timedelta(days=90)
        pagination = GenerationRecord.query.filter(
            GenerationRecord.user_id == user_id,
            GenerationRecord.created_at >= ninety_days_ago,
            GenerationRecord.prompt != "解读验光单"  # hide this assistant's records
        ).order_by(GenerationRecord.created_at.desc())\
            .paginate(
                page=page,
                per_page=per_page,
                # Cap the page size so one request cannot load an
                # arbitrarily large slice of the table.
                max_per_page=100,
                error_out=False
            )

        history_list = []
        for r in pagination.items:
            # One corrupt row must not take the whole page down.
            try:
                raw_urls = json.loads(r.image_urls) if r.image_urls else []
            except (TypeError, ValueError):
                raw_urls = []
            if not isinstance(raw_urls, list):
                raw_urls = []

            formatted_urls = []
            for u in raw_urls:
                if isinstance(u, str):
                    # Legacy rows: reuse the full image as its thumbnail.
                    formatted_urls.append({"url": u, "thumb": u})
                elif isinstance(u, dict):
                    # Videos without a preview get a generic play icon.
                    if u.get('type') == 'video' and not u.get('thumb'):
                        u['thumb'] = "https://img.icons8.com/flat-round/64/000000/play--v1.png"
                    formatted_urls.append(u)
                # Entries of any other type are dropped.

            history_list.append({
                "id": r.id,
                "prompt": r.prompt,
                "model": r.model,
                "urls": formatted_urls,
                "created_at": (r.created_at + timedelta(hours=8)).strftime('%b %d, %H:%M')
            })
        return jsonify({
            "history": history_list,
            "has_next": pagination.has_next,
            "total": pagination.total
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@api_bp.route('/api/download_proxy', methods=['GET'])
@login_required
def download_proxy():
    """Stream a remote file back to the browser as an attachment.

    Query parameters: ``url`` (required) and ``filename`` (defaults to
    ``video-<unix time>.mp4``). The URL, and every redirect hop, must be
    http(s) and must either point at the configured MinIO host or resolve
    only to public addresses; anything else is rejected with 400. Fetch
    failures are logged and reported as 500.
    """
    from urllib.parse import urljoin

    url = request.args.get('url')
    filename = request.args.get('filename', f"video-{int(time.time())}.mp4")
    if not url:
        return jsonify({"error": "缺少 URL 参数"}), 400

    max_redirects = 5
    upstream = None
    try:
        # Follow redirects by hand so each hop is validated; automatic
        # redirects would let a public URL bounce the request to an
        # internal address.
        target = url
        for _ in range(max_redirects + 1):
            if not _is_allowed_download_url(target):
                return jsonify({"error": "不允许的下载地址"}), 400
            upstream = requests.get(target, stream=True, timeout=60, allow_redirects=False)
            if not upstream.is_redirect:
                break
            next_target = urljoin(target, upstream.headers['Location'])
            upstream.close()
            upstream = None
            target = next_target
        else:
            raise RuntimeError("too many redirects")
        upstream.raise_for_status()

        headers = {
            'Content-Type': upstream.headers.get('Content-Type') or 'application/octet-stream'
        }
        # Neutralise characters that could break out of the quoted header
        # value or smuggle a path; non-ASCII names go in ``filename*``.
        safe_name = filename.translate({ord(c): '_' for c in '\r\n"\\/'}) or 'download'
        ascii_name = safe_name.encode('ascii', 'ignore').decode('ascii') or 'download'
        headers['Content-Disposition'] = (
            f'attachment; filename="{ascii_name}"; '
            f"filename*=UTF-8''{quote(safe_name, safe='')}"
        )

        response_source = upstream

        def stream_body():
            # Release the upstream connection when streaming ends or the
            # client disconnects.
            try:
                for chunk in response_source.iter_content(chunk_size=4096):
                    yield chunk
            finally:
                response_source.close()

        return current_app.response_class(stream_body(), headers=headers)
    except Exception as e:
        if upstream is not None:
            upstream.close()
        system_logger.error(f"代理下载失败: {str(e)}")
        return jsonify({"error": "下载失败"}), 500


def _is_allowed_download_url(url):
    """Return True when ``url`` is safe for the server to fetch.

    Allowed: http(s) URLs on the configured MinIO host, and http(s) URLs
    whose host resolves only to public (non-private, non-loopback, ...)
    addresses. NOTE(review): the address is resolved again when the request
    is made, so DNS rebinding is not fully ruled out.
    """
    import ipaddress
    import socket
    from urllib.parse import urlsplit

    try:
        parts = urlsplit(url)
        host = parts.hostname
        port = parts.port
        storage_netloc = urlsplit(str(Config.MINIO['public_url'])).netloc
    except ValueError:
        return False
    if parts.scheme not in ("http", "https") or not host:
        return False

    # Files mirrored to our own storage are always downloadable, even when
    # that storage lives on an internal address.
    if storage_netloc and parts.netloc == storage_netloc:
        return True

    try:
        infos = socket.getaddrinfo(
            host,
            port or (443 if parts.scheme == "https" else 80),
            type=socket.SOCK_STREAM
        )
    except (socket.gaierror, UnicodeError):
        return False
    for info in infos:
        # Strip an IPv6 scope id ("%eth0") before parsing.
        address = ipaddress.ip_address(info[4][0].split('%', 1)[0])
        if (address.is_private or address.is_loopback or address.is_link_local
                or address.is_multicast or address.is_reserved or address.is_unspecified):
            return False
    return bool(infos)