feat: Implement AI image, video, and chat generation features with integrated point and API key management.

This commit is contained in:
公司git 2026-01-20 09:53:06 +08:00
parent 824508f6a4
commit 1202291e4b
3 changed files with 81 additions and 101 deletions

View File

@ -68,7 +68,7 @@ def generate():
# 2. 扣除积分 (如果是试用模式) # 2. 扣除积分 (如果是试用模式)
if use_trial: if use_trial:
deduct_points(user, cost) deduct_points(user_id, cost)
model_value = data.get('model') model_value = data.get('model')
prompt = data.get('prompt') prompt = data.get('prompt')
@ -84,7 +84,6 @@ def generate():
"prompt": prompt, "prompt": prompt,
"model": model_value, "model": model_value,
"response_format": "url", "response_format": "url",
"n": int(data.get('n', 1)), # 添加数量
"aspect_ratio": data.get('ratio') "aspect_ratio": data.get('ratio')
} }
image_data = data.get('image_data', []) image_data = data.get('image_data', [])
@ -120,7 +119,7 @@ def video_generate():
return jsonify({"error": error}), 400 return jsonify({"error": error}), 400
# 2. 扣除积分 # 2. 扣除积分
deduct_points(user, cost) deduct_points(user_id, cost)
# 3. 构造 Payload # 3. 构造 Payload
payload = { payload = {

View File

@ -28,7 +28,6 @@ def validate_generation_request(user, data):
is_premium = data.get('is_premium', False) is_premium = data.get('is_premium', False)
input_key = data.get('apiKey') input_key = data.get('apiKey')
model_value = data.get('model') model_value = data.get('model')
n = int(data.get('n', 1)) # 获取请求中的数量
target_api = Config.AI_API target_api = Config.AI_API
api_key = None api_key = None
@ -52,12 +51,9 @@ def validate_generation_request(user, data):
return None, None, 0, False, "可用积分已耗尽,请充值或切换至自定义 Key 模式" return None, None, 0, False, "可用积分已耗尽,请充值或切换至自定义 Key 模式"
# 计算单价 # 计算单价
base_cost = get_model_cost(model_value, is_video=False) cost = get_model_cost(model_value, is_video=False)
if use_trial and is_premium: if use_trial and is_premium:
base_cost *= 2 cost *= 2
# 总消耗 = 单价 * 数量
cost = base_cost * n
if use_trial: if use_trial:
if user.points < cost: if user.points < cost:
@ -65,21 +61,23 @@ def validate_generation_request(user, data):
return api_key, target_api, cost, use_trial, None return api_key, target_api, cost, use_trial, None
def deduct_points(user, cost): def deduct_points(user_id, cost):
"""扣除积分""" """原子扣除积分"""
user.points -= cost user = db.session.query(User).filter_by(id=user_id).populate_existing().with_for_update().first()
user.has_used_points = True if user:
db.session.commit() user.points -= cost
user.has_used_points = True
db.session.commit()
def refund_points(user_id, cost): def refund_points(user_id, cost):
"""退还积分""" """原子退还积分"""
try: try:
user = db.session.get(User, user_id) user = db.session.query(User).filter_by(id=user_id).populate_existing().with_for_update().first()
if user: if user:
user.points += cost user.points += cost
db.session.commit() db.session.commit()
except: except:
pass db.session.rollback()
def handle_chat_generation_sync(user_id, api_key, model_value, prompt, use_trial, cost): def handle_chat_generation_sync(user_id, api_key, model_value, prompt, use_trial, cost):
"""同步处理对话类模型""" """同步处理对话类模型"""

View File

@ -503,103 +503,86 @@ document.getElementById('submitBtn').onclick = async () => {
image_data = await Promise.all(uploadedFiles.map(f => readFileAsBase64(f))); image_data = await Promise.all(uploadedFiles.map(f => readFileAsBase64(f)));
} }
// 2. 发起单次生成请求 (包含数量 n) // 2. 并行启动多个生成任务
btnText.innerText = `AI 构思中...`; btnText.innerText = `AI 构思中 (0/${num})...`;
let finishedCount = 0;
// 预先创建 Slot (对应数量) const startTask = async (index) => {
const slots = [];
for (let i = 0; i < num; i++) {
const slot = document.createElement('div'); const slot = document.createElement('div');
slot.className = 'image-frame relative bg-white/50 animate-pulse min-h-[200px] flex items-center justify-center rounded-[2.5rem] border border-slate-100 shadow-sm'; slot.className = 'image-frame relative bg-white/50 animate-pulse min-h-[200px] flex items-center justify-center rounded-[2.5rem] border border-slate-100 shadow-sm';
slot.innerHTML = `<div class="text-slate-400 text-[10px] font-bold italic">正在排队中...</div>`; slot.innerHTML = `<div class="text-slate-400 text-[10px] font-bold italic">正在排队中...</div>`;
grid.appendChild(slot); grid.appendChild(slot);
slots.push(slot);
}
// 3. 提交任务 try {
const r = await fetch('/api/generate', { // 1. 发起生成请求,获取任务 ID
method: 'POST', const r = await fetch('/api/generate', {
headers: { 'Content-Type': 'application/json' }, method: 'POST',
body: JSON.stringify({ headers: { 'Content-Type': 'application/json' },
mode: currentMode, body: JSON.stringify({
n: num, // 发送数量 mode: currentMode,
is_premium: document.getElementById('isPremium')?.checked || false, is_premium: document.getElementById('isPremium')?.checked || false,
apiKey: currentMode === 'key' ? apiKey : '', apiKey: currentMode === 'key' ? apiKey : '',
prompt: document.getElementById('manualPrompt').value, prompt: document.getElementById('manualPrompt').value,
model: document.getElementById('modelSelect').value, model: document.getElementById('modelSelect').value,
ratio: document.getElementById('ratioSelect').value, ratio: document.getElementById('ratioSelect').value,
size: document.getElementById('sizeSelect').value, size: document.getElementById('sizeSelect').value,
image_data // 发送 Base64 数组 image_data // 发送 Base64 数组
}) })
});
const res = await r.json();
if (res.error) throw new Error(res.error);
// 如果直接返回了 data (比如聊天模型),直接显示在第一个 slot
if (res.data) {
displayResult(slots[0], res.data[0]);
// 移除其他多余的 slots
for (let i = 1; i < slots.length; i++) slots[i].remove();
return;
}
// 4. 轮询任务状态
const taskId = res.task_id;
let pollCount = 0;
const maxPolls = 500;
while (pollCount < maxPolls) {
await new Promise(resolve => setTimeout(resolve, 2000));
pollCount++;
const statusR = await fetch(`/api/task_status/${taskId}`);
const statusRes = await statusR.json();
if (statusRes.status === 'complete') {
const urls = statusRes.urls;
currentGeneratedUrls = urls;
// 填充所有 slots
urls.forEach((url, idx) => {
if (slots[idx]) {
displayResult(slots[idx], { url });
}
}); });
const res = await r.json();
if (res.error) throw new Error(res.error);
// 如果返回的数量少于 slots 数量,移除多余的 // 如果直接返回了 data (比如聊天模型),直接显示
if (urls.length < slots.length) { if (res.data) {
for (let i = urls.length; i < slots.length; i++) slots[i].remove(); displayResult(slot, res.data[0]);
return;
} }
btnText.innerText = `生成完成`; // 2. 轮询任务状态
if (currentMode === 'trial') checkAuth(); const taskId = res.task_id;
return; let pollCount = 0;
} else if (statusRes.status === 'error') { const maxPolls = 500; // 最多轮询约 16 分钟 (2s * 500 = 1000s)
throw new Error(statusRes.message || "生成失败");
} else { while (pollCount < maxPolls) {
// 更新所有 slot 的轮询状态显示 await new Promise(resolve => setTimeout(resolve, 2000));
slots.forEach(slot => { pollCount++;
slot.innerHTML = `<div class="text-slate-400 text-[10px] font-bold italic">AI 正在努力创作中 (${pollCount * 2}s)...</div>`;
}); const statusR = await fetch(`/api/task_status/${taskId}`);
const statusRes = await statusR.json();
if (statusRes.status === 'complete') {
const imgUrl = statusRes.urls[0];
currentGeneratedUrls.push(imgUrl);
displayResult(slot, { url: imgUrl });
finishedCount++;
btnText.innerText = `AI 构思中 (${finishedCount}/${num})...`;
if (currentMode === 'trial') checkAuth();
return; // 任务正常结束
} else if (statusRes.status === 'error') {
throw new Error(statusRes.message || "生成失败");
} else {
// 继续轮询状态显示
slot.innerHTML = `<div class="text-slate-400 text-[10px] font-bold italic">AI 正在努力创作中 (${pollCount * 2}s)...</div>`;
}
}
throw new Error("生成超时,请稍后在历史记录中查看");
} catch (e) {
slot.className = 'image-frame relative bg-red-50/50 flex items-center justify-center rounded-[2.5rem] border border-red-100 p-6';
if (e.message.includes('401') || e.message.includes('请先登录')) {
slot.innerHTML = `<div class="text-red-400 text-[10px] font-bold text-center">登录已过期,请重新登录</div>`;
} else {
slot.innerHTML = `<div class="text-red-400 text-[10px] font-bold text-center">生成异常: ${e.message}</div>`;
}
} }
} };
throw new Error("生成超时,请稍后在历史记录中查看");
const tasks = Array.from({ length: num }, (_, i) => startTask(i));
await Promise.all(tasks);
} catch (e) { } catch (e) {
// 在第一个 slot 显示错误 showToast('创作引擎中断: ' + e.message, 'error');
if (grid.firstChild) { document.getElementById('placeholder').classList.remove('hidden');
const firstSlot = grid.firstChild;
firstSlot.className = 'image-frame relative bg-red-50/50 flex items-center justify-center rounded-[2.5rem] border border-red-100 p-6';
if (e.message.includes('401') || e.message.includes('请先登录')) {
firstSlot.innerHTML = `<div class="text-red-400 text-[10px] font-bold text-center">登录已过期,请重新登录</div>`;
} else {
firstSlot.innerHTML = `<div class="text-red-400 text-[10px] font-bold text-center">生成异常: ${e.message}</div>`;
}
// 移除其他 slots
while (grid.children.length > 1) {
grid.lastChild.remove();
}
}
} finally { } finally {
btn.disabled = false; btn.disabled = false;
btnText.innerText = "立即生成作品"; btnText.innerText = "立即生成作品";