261 lines
9.8 KiB
Python
261 lines
9.8 KiB
Python
|
|
import os
|
|||
|
|
import uuid
|
|||
|
|
import json
|
|||
|
|
import requests
|
|||
|
|
import io
|
|||
|
|
import threading
|
|||
|
|
import time
|
|||
|
|
from flask import Blueprint, request, jsonify, session, current_app
|
|||
|
|
from urllib.parse import quote
|
|||
|
|
from config import Config
|
|||
|
|
from extensions import s3_client, redis_client, db
|
|||
|
|
from models import GenerationRecord, User, SystemDict
|
|||
|
|
from middlewares.auth import login_required
|
|||
|
|
from services.logger import system_logger
|
|||
|
|
|
|||
|
|
# Blueprint that the JSON API routes in this module register on.
api_bp = Blueprint(name='api', import_name=__name__)
|
|||
|
|
|
|||
|
|
def sync_images_background(app, record_id, raw_urls, max_attempts=3):
    """Copy generated images into MinIO and point the record at the copies.

    Meant to run in a background thread, so it pushes its own application
    context. Each URL is downloaded and uploaded up to ``max_attempts`` times,
    sleeping 1s, 2s, 4s, ... between attempts (never after the last one). A URL
    that still fails keeps its original upstream address, so the record always
    has something to display.

    Args:
        app: the concrete Flask application object (not the ``current_app`` proxy).
        record_id: id of the ``GenerationRecord`` whose ``image_urls`` is rewritten.
        raw_urls: image URLs returned by the generation API.
        max_attempts: attempts per image before falling back to the raw URL.
    """
    with app.app_context():
        minio_urls = []
        for raw_url in raw_urls:
            stored_url = None
            for attempt in range(max_attempts):
                try:
                    img_resp = requests.get(raw_url, timeout=30)
                    if img_resp.status_code == 200:
                        filename = f"gen-{uuid.uuid4().hex}.png"
                        s3_client.upload_fileobj(
                            io.BytesIO(img_resp.content),
                            Config.MINIO["bucket"],
                            filename,
                            ExtraArgs={"ContentType": "image/png"},
                        )
                        stored_url = f"{Config.MINIO['public_url']}{quote(filename)}"
                        break
                    # A non-200 reply is also a failed attempt: log it and back off
                    # instead of hammering the upstream immediately.
                    print(f"⚠️ 第 {attempt+1} 次同步失败: HTTP {img_resp.status_code}")
                except Exception as e:
                    print(f"⚠️ 第 {attempt+1} 次同步失败: {e}")
                if attempt < max_attempts - 1:
                    time.sleep(2 ** attempt)  # exponential backoff between attempts only
            # If every attempt failed, keep the upstream URL so there is still something to view.
            minio_urls.append(stored_url if stored_url is not None else raw_url)

        # Persist the final URL list on the record.
        try:
            record = GenerationRecord.query.get(record_id)
            if record:
                record.image_urls = json.dumps(minio_urls)
                db.session.commit()
                print(f"✅ 记录 {record_id} 图片已完成持久化同步")
        except Exception as e:
            # Discard the failed transaction so this thread's session is left clean.
            db.session.rollback()
            print(f"❌ 更新记录失败: {e}")
|
|||
|
|
|
|||
|
|
@api_bp.route('/api/config')
def get_config():
    """Build the front-end option lists from active SystemDict entries.

    Active entries, ordered by ``sort_order`` descending, are grouped by
    ``dict_type`` into ``models``, ``ratios``, ``prompts`` and ``sizes``. Each
    option is ``{"label", "value"}``; model options also carry ``cost``.
    Entries of any other ``dict_type`` are ignored. On any error the message is
    returned with HTTP 500.
    """
    # dict_type -> key of the list it belongs to in the response
    section_for_type = {
        'ai_model': "models",
        'aspect_ratio': "ratios",
        'prompt_tpl': "prompts",
        'ai_image_size': "sizes",
    }
    try:
        entries = (
            SystemDict.query
            .filter_by(is_active=True)
            .order_by(SystemDict.sort_order.desc())
            .all()
        )

        payload = {section: [] for section in section_for_type.values()}
        for entry in entries:
            section = section_for_type.get(entry.dict_type)
            if section is None:
                continue
            option = {"label": entry.label, "value": entry.value}
            if section == "models":
                option["cost"] = entry.cost
            payload[section].append(option)

        return jsonify(payload)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|||
|
|
|
|||
|
|
def _safe_extension(filename):
    """Return the lower-cased extension of ``filename`` (including the dot).

    Anything other than a short, plain-ASCII alphanumeric suffix is dropped, so a
    client-chosen filename cannot put odd characters into the object key.
    """
    ext = os.path.splitext(filename)[1].lower()
    suffix = ext[1:]
    if suffix.isascii() and suffix.isalnum() and len(suffix) <= 10:
        return ext
    return ""


def _is_allowed_image_type(content_type):
    """Accept ``image/*`` MIME types except SVG, which can embed script."""
    ctype = (content_type or "").lower()
    return ctype.startswith("image/") and "svg" not in ctype


@api_bp.route('/api/upload', methods=['POST'])
@login_required
def upload():
    """Store uploaded reference images in MinIO and return their URLs.

    Expects multipart form data with files under the ``images`` field. Parts
    without a filename are skipped. Each file is stored under a random name with
    the client-declared content type and exposed through
    ``Config.MINIO['public_url']``. Only non-SVG image types are accepted, so the
    endpoint cannot be used to publish HTML/script documents at those URLs.

    Returns ``{"urls": [...]}``; 400 if any file has an unsupported type
    (nothing is uploaded in that case); 500 with the error text on failure.
    """
    try:
        # An empty file input is typically submitted as a part with no filename.
        files = [f for f in request.files.getlist('images') if f.filename]

        # Validate every file before uploading any, so a rejected file does not
        # leave the earlier ones orphaned in the bucket.
        for f in files:
            if not _is_allowed_image_type(f.content_type):
                return jsonify({"error": f"不支持的文件类型: {f.content_type}"}), 400

        img_urls = []
        for f in files:
            filename = f"{uuid.uuid4().hex}{_safe_extension(f.filename)}"
            s3_client.upload_fileobj(
                f, Config.MINIO["bucket"], filename,
                ExtraArgs={"ContentType": f.content_type},
            )
            img_urls.append(f"{Config.MINIO['public_url']}{quote(filename)}")
        return jsonify({"urls": img_urls})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|||
|
|
|
|||
|
|
def _parse_bool(value):
    """Interpret a request flag as a boolean.

    Form posts deliver every value as a string, and ``bool("false")`` is True,
    so recognise the usual textual spellings instead.
    """
    if isinstance(value, str):
        return value.strip().lower() in ("1", "true", "yes", "on")
    return bool(value)


def _reserve_points(user_id, cost):
    """Atomically deduct ``cost`` points if the user's balance covers it, then commit.

    The balance check and the subtraction run in one conditional UPDATE, so two
    concurrent requests cannot both pass the check and overwrite each other's
    deduction (as a Python-side ``user.points -= cost`` would).
    Returns True when the points were deducted.
    """
    # NOTE(review): assumes User's primary key attribute is `id` - confirm against models.User.
    updated = User.query.filter(User.id == user_id, User.points >= cost).update(
        {User.points: User.points - cost}, synchronize_session=False
    )
    db.session.commit()
    return updated == 1


def _refund_points(user_id, cost):
    """Give back points reserved by ``_reserve_points`` and commit.

    Rolls back first so that a transaction left broken by the failure being
    handled (e.g. a failed record insert) cannot also make the refund fail.
    """
    db.session.rollback()
    User.query.filter(User.id == user_id).update(
        {User.points: User.points + cost}, synchronize_session=False
    )
    db.session.commit()


@api_bp.route('/api/generate', methods=['POST'])
@login_required
def generate():
    """Generate images for the logged-in user and record the result.

    Request fields (JSON body or form):
        mode: ``'key'`` uses the user's own API key; anything else (default
            ``'trial'``) spends points against the service's trial key.
        is_premium: in points mode, use the premium key at double the cost.
        apiKey: key for ``'key'`` mode; a newly entered key is saved on the user.
        model, prompt, ratio, image_urls: forwarded to the generation API;
            size is forwarded only for ``nano-banana-2``.

    In points mode the cost comes from the model's ``ai_model`` SystemDict entry
    (1 if missing) and is deducted before the API call. It is refunded if the API
    errors, returns no images, or anything fails before the record is saved. On
    success the upstream URLs are returned immediately and a background thread
    copies the images to MinIO.
    """
    try:
        user_id = session.get('user_id')
        user = User.query.get(user_id)
        if user is None:
            # The session refers to a user that no longer exists.
            return jsonify({"error": "用户不存在或登录已失效"}), 401
        # Captured now: after a failed commit, reading the expired `user.id` would
        # itself raise before the refund could run.
        account_id = user.id

        data = request.json if request.is_json else request.form
        mode = data.get('mode', 'trial')
        is_premium = _parse_bool(data.get('is_premium', False))
        input_key = data.get('apiKey')

        target_api = Config.AI_API
        api_key = None
        use_trial = False

        if mode == 'key':
            # Custom-key mode: prefer the key sent with this request, else the saved one.
            api_key = input_key or user.api_key
            if not api_key:
                return jsonify({"error": "请先输入您的 API 密钥"}), 400
        elif user.points > 0:
            # Points/trial mode: premium requests use their dedicated key.
            api_key = Config.PREMIUM_KEY if is_premium else Config.TRIAL_KEY
            target_api = Config.TRIAL_API
            use_trial = True
        else:
            return jsonify({"error": "可用积分已耗尽,请充值或切换至自定义 Key 模式"}), 400

        # Remember a newly entered custom key for next time.
        if mode == 'key' and input_key and input_key != user.api_key:
            user.api_key = input_key
            db.session.commit()

        # Point cost comes from the model's dictionary entry; premium costs double.
        model = data.get('model')
        model_dict = SystemDict.query.filter_by(dict_type='ai_model', value=model).first()
        cost = model_dict.cost if model_dict is not None and model_dict.cost is not None else 1
        if use_trial and is_premium:
            cost *= 2

        # Deduct up front ("charge on click"); every failure path below refunds.
        tier = '优质' if is_premium else '普通'
        if use_trial:
            if not _reserve_points(account_id, cost):
                return jsonify({"error": f"可用积分不足,{tier}模式需要 {cost} 积分,您当前剩余 {user.points} 积分"}), 400
            system_logger.info(f"积分预扣除 ({tier}试用)", phone=user.phone, cost=cost, remaining_points=user.points)

        try:
            prompt = data.get('prompt')
            ratio = data.get('ratio')
            size = data.get('size')
            input_img_urls = data.get('image_urls', [])

            payload = {
                "prompt": prompt,
                "model": model,
                "response_format": "url",
                "aspect_ratio": ratio,
                "image": input_img_urls,
            }
            if model == "nano-banana-2" and size:
                payload["image_size"] = size

            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
            resp = requests.post(target_api, json=payload, headers=headers, timeout=300)

            if resp.status_code != 200:
                # Upstream error: refund, then pass the upstream response through.
                if use_trial:
                    _refund_points(account_id, cost)
                    system_logger.warning("API 报错,积分已退还", phone=user.phone, status_code=resp.status_code)
                return jsonify({"error": resp.text}), resp.status_code

            api_result = resp.json()
            raw_urls = [item['url'] for item in (api_result.get('data') or [])]
            if not raw_urls:
                # A 200 reply without images is still a failed generation: don't charge for it.
                if use_trial:
                    _refund_points(account_id, cost)
                    system_logger.warning("API 未返回图片,积分已退还", phone=user.phone)
                return jsonify({"error": "生成服务未返回任何图片"}), 502

            # Save the record right away with the upstream URLs; the background
            # sync swaps in the MinIO URLs later.
            new_record = GenerationRecord(
                user_id=user_id,
                prompt=prompt,
                model=model,
                image_urls=json.dumps(raw_urls),
            )
            db.session.add(new_record)
            db.session.commit()
        except Exception as e:
            # Anything failing before the record is saved: refund the reserved points.
            if use_trial:
                _refund_points(account_id, cost)
                system_logger.error("生成异常,积分已退还", phone=user.phone, error=str(e))
            return jsonify({"error": str(e)}), 500

        system_logger.info("用户生成图片成功", phone=user.phone, model=model, record_id=new_record.id)

        # Copy the images to MinIO in a background thread so the response is not blocked.
        app = current_app._get_current_object()
        threading.Thread(
            target=sync_images_background,
            args=(app, new_record.id, raw_urls),
        ).start()

        # Return the upstream URLs immediately for display.
        return jsonify({
            "data": [{"url": url} for url in raw_urls],
            "message": "生成成功!作品正在后台同步至云存储。",
        })
    except Exception as e:
        db.session.rollback()
        return jsonify({"error": str(e)}), 500
|
|||
|
|
|
|||
|
|
@api_bp.route('/api/history', methods=['GET'])
@login_required
def get_history():
    """Return the current user's generation records from the last 90 days, paginated.

    Query parameters:
        page: 1-based page number (default 1).
        per_page: records per page (default 10, capped at 100).

    Responds with ``history`` (newest first; each item has id, prompt, model,
    urls and a ``YYYY-MM-DD HH:MM`` time), ``has_next`` and ``total``. On any
    error the message is returned with HTTP 500.
    """
    try:
        from datetime import datetime, timedelta

        user_id = session.get('user_id')
        page = request.args.get('page', 1, type=int)
        per_page = request.args.get('per_page', 10, type=int)

        # Only records created within the last 90 days are listed.
        ninety_days_ago = datetime.utcnow() - timedelta(days=90)

        pagination = (
            GenerationRecord.query
            .filter(
                GenerationRecord.user_id == user_id,
                GenerationRecord.created_at >= ninety_days_ago,
            )
            .order_by(GenerationRecord.created_at.desc())
            # max_per_page stops a client from pulling an unbounded page in one request.
            .paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
        )

        return jsonify({
            "history": [{
                "id": r.id,
                "prompt": r.prompt,
                "model": r.model,
                # Tolerate a NULL column instead of failing the whole page.
                "urls": json.loads(r.image_urls) if r.image_urls else [],
                "time": r.created_at.strftime('%Y-%m-%d %H:%M'),
            } for r in pagination.items],
            "has_next": pagination.has_next,
            "total": pagination.total,
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|