diff --git a/blueprints/api.py b/blueprints/api.py index 72fee15..1ea786e 100644 --- a/blueprints/api.py +++ b/blueprints/api.py @@ -68,7 +68,7 @@ def generate(): # 2. 扣除积分 (如果是试用模式) if use_trial: - deduct_points(user, cost) + deduct_points(user_id, cost) model_value = data.get('model') prompt = data.get('prompt') @@ -84,7 +84,6 @@ def generate(): "prompt": prompt, "model": model_value, "response_format": "url", - "n": int(data.get('n', 1)), # 添加数量 "aspect_ratio": data.get('ratio') } image_data = data.get('image_data', []) @@ -120,7 +119,7 @@ def video_generate(): return jsonify({"error": error}), 400 # 2. 扣除积分 - deduct_points(user, cost) + deduct_points(user_id, cost) # 3. 构造 Payload payload = { diff --git a/services/generation_service.py b/services/generation_service.py index d194ede..821a9c9 100644 --- a/services/generation_service.py +++ b/services/generation_service.py @@ -28,7 +28,6 @@ def validate_generation_request(user, data): is_premium = data.get('is_premium', False) input_key = data.get('apiKey') model_value = data.get('model') - n = int(data.get('n', 1)) # 获取请求中的数量 target_api = Config.AI_API api_key = None @@ -52,12 +51,9 @@ def validate_generation_request(user, data): return None, None, 0, False, "可用积分已耗尽,请充值或切换至自定义 Key 模式" # 计算单价 - base_cost = get_model_cost(model_value, is_video=False) + cost = get_model_cost(model_value, is_video=False) if use_trial and is_premium: - base_cost *= 2 - - # 总消耗 = 单价 * 数量 - cost = base_cost * n + cost *= 2 if use_trial: if user.points < cost: @@ -65,21 +61,23 @@ def validate_generation_request(user, data): return api_key, target_api, cost, use_trial, None -def deduct_points(user, cost): - """扣除积分""" - user.points -= cost - user.has_used_points = True - db.session.commit() +def deduct_points(user_id, cost): + """原子扣除积分""" + user = db.session.query(User).filter_by(id=user_id).populate_existing().with_for_update().first() + if user: + user.points -= cost + user.has_used_points = True + db.session.commit() def refund_points(user_id, cost): - """退还积分""" + """原子退还积分""" try: - user = db.session.get(User, user_id) + user = db.session.query(User).filter_by(id=user_id).populate_existing().with_for_update().first() if user: user.points += cost db.session.commit() except: - pass + db.session.rollback() def handle_chat_generation_sync(user_id, api_key, model_value, prompt, use_trial, cost): """同步处理对话类模型""" diff --git a/static/js/main.js b/static/js/main.js index 0a28ece..fb6b1fb 100644 --- a/static/js/main.js +++ b/static/js/main.js @@ -503,103 +503,86 @@ document.getElementById('submitBtn').onclick = async () => { image_data = await Promise.all(uploadedFiles.map(f => readFileAsBase64(f))); } - // 2. 发起单次生成请求 (包含数量 n) - btnText.innerText = `AI 构思中...`; + // 2. 并行启动多个生成任务 + btnText.innerText = `AI 构思中 (0/${num})...`; + let finishedCount = 0; - // 预先创建 Slot (对应数量) - const slots = []; - for (let i = 0; i < num; i++) { + const startTask = async (index) => { const slot = document.createElement('div'); slot.className = 'image-frame relative bg-white/50 animate-pulse min-h-[200px] flex items-center justify-center rounded-[2.5rem] border border-slate-100 shadow-sm'; slot.innerHTML = `
正在排队中...
`; grid.appendChild(slot); - slots.push(slot); - } - // 3. 提交任务 - const r = await fetch('/api/generate', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - mode: currentMode, - n: num, // 发送数量 - is_premium: document.getElementById('isPremium')?.checked || false, - apiKey: currentMode === 'key' ? apiKey : '', - prompt: document.getElementById('manualPrompt').value, - model: document.getElementById('modelSelect').value, - ratio: document.getElementById('ratioSelect').value, - size: document.getElementById('sizeSelect').value, - image_data // 发送 Base64 数组 - }) - }); - const res = await r.json(); - if (res.error) throw new Error(res.error); - - // 如果直接返回了 data (比如聊天模型),直接显示在第一个 slot - if (res.data) { - displayResult(slots[0], res.data[0]); - // 移除其他多余的 slots - for (let i = 1; i < slots.length; i++) slots[i].remove(); - return; - } - - // 4. 轮询任务状态 - const taskId = res.task_id; - let pollCount = 0; - const maxPolls = 500; - - while (pollCount < maxPolls) { - await new Promise(resolve => setTimeout(resolve, 2000)); - pollCount++; - - const statusR = await fetch(`/api/task_status/${taskId}`); - const statusRes = await statusR.json(); - - if (statusRes.status === 'complete') { - const urls = statusRes.urls; - currentGeneratedUrls = urls; - - // 填充所有 slots - urls.forEach((url, idx) => { - if (slots[idx]) { - displayResult(slots[idx], { url }); - } + try { + // 1. 发起生成请求,获取任务 ID + const r = await fetch('/api/generate', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + mode: currentMode, + is_premium: document.getElementById('isPremium')?.checked || false, + apiKey: currentMode === 'key' ? apiKey : '', + prompt: document.getElementById('manualPrompt').value, + model: document.getElementById('modelSelect').value, + ratio: document.getElementById('ratioSelect').value, + size: document.getElementById('sizeSelect').value, + image_data // 发送 Base64 数组 + }) }); + const res = await r.json(); + if (res.error) throw new Error(res.error); - // 如果返回的数量少于 slots 数量,移除多余的 - if (urls.length < slots.length) { - for (let i = urls.length; i < slots.length; i++) slots[i].remove(); + // 如果直接返回了 data (比如聊天模型),直接显示 + if (res.data) { + displayResult(slot, res.data[0]); + return; } - btnText.innerText = `生成完成`; - if (currentMode === 'trial') checkAuth(); - return; - } else if (statusRes.status === 'error') { - throw new Error(statusRes.message || "生成失败"); - } else { - // 更新所有 slot 的轮询状态显示 - slots.forEach(slot => { - slot.innerHTML = `
AI 正在努力创作中 (${pollCount * 2}s)...
`; - }); + // 2. 轮询任务状态 + const taskId = res.task_id; + let pollCount = 0; + const maxPolls = 500; // 最多轮询约 16 分钟 (2s * 500 = 1000s) + + while (pollCount < maxPolls) { + await new Promise(resolve => setTimeout(resolve, 2000)); + pollCount++; + + const statusR = await fetch(`/api/task_status/${taskId}`); + const statusRes = await statusR.json(); + + if (statusRes.status === 'complete') { + const imgUrl = statusRes.urls[0]; + currentGeneratedUrls.push(imgUrl); + displayResult(slot, { url: imgUrl }); + finishedCount++; + btnText.innerText = `AI 构思中 (${finishedCount}/${num})...`; + if (currentMode === 'trial') checkAuth(); + return; // 任务正常结束 + } else if (statusRes.status === 'error') { + throw new Error(statusRes.message || "生成失败"); + } else { + // 继续轮询状态显示 + slot.innerHTML = `
AI 正在努力创作中 (${pollCount * 2}s)...
`; + } + } + throw new Error("生成超时,请稍后在历史记录中查看"); + + } catch (e) { + slot.className = 'image-frame relative bg-red-50/50 flex items-center justify-center rounded-[2.5rem] border border-red-100 p-6'; + if (e.message.includes('401') || e.message.includes('请先登录')) { + slot.innerHTML = `
登录已过期,请重新登录
`; + } else { + slot.innerHTML = `
生成异常: ${e.message}
`; + } } - } - throw new Error("生成超时,请稍后在历史记录中查看"); + }; + + const tasks = Array.from({ length: num }, (_, i) => startTask(i)); + await Promise.all(tasks); } catch (e) { - // 在第一个 slot 显示错误 - if (grid.firstChild) { - const firstSlot = grid.firstChild; - firstSlot.className = 'image-frame relative bg-red-50/50 flex items-center justify-center rounded-[2.5rem] border border-red-100 p-6'; - if (e.message.includes('401') || e.message.includes('请先登录')) { - firstSlot.innerHTML = `
登录已过期,请重新登录
`; - } else { - firstSlot.innerHTML = `
生成异常: ${e.message}
`; - } - // 移除其他 slots - while (grid.children.length > 1) { - grid.lastChild.remove(); - } - } + showToast('创作引擎中断: ' + e.message, 'error'); + document.getElementById('placeholder').classList.remove('hidden'); } finally { btn.disabled = false; btnText.innerText = "立即生成作品";