From 824508f6a4232b8a10945ebd4f0e4a0fb491b847 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=85=AC=E5=8F=B8git?= <240241002@qq.com>
Date: Tue, 20 Jan 2026 09:29:01 +0800
Subject: [PATCH 1/3] feat: implement core API and generation services for AI
image, video, and chat functionalities, user history, and point management.
---
README.md | 8 ++
blueprints/api.py | 1 +
services/generation_service.py | 11 +-
static/js/main.js | 194 ++++++++++++++++++---------------
4 files changed, 125 insertions(+), 89 deletions(-)
diff --git a/README.md b/README.md
index 52f32f1..71dc125 100644
--- a/README.md
+++ b/README.md
@@ -5,6 +5,14 @@
### 1. 环境准备
确保已安装 Python 3.8+ 和 PostgreSQL / Redis。
+**创建虚拟环境**
+```bash
+# 如果 python 命令不可用,请尝试使用 py
+python -m venv .venv
+# 或者
+py -m venv .venv
+```
+
**激活虚拟环境** (推荐)
```bash
# Windows (PowerShell)
diff --git a/blueprints/api.py b/blueprints/api.py
index b556c52..72fee15 100644
--- a/blueprints/api.py
+++ b/blueprints/api.py
@@ -84,6 +84,7 @@ def generate():
"prompt": prompt,
"model": model_value,
"response_format": "url",
+ "n": int(data.get('n', 1)), # 添加数量
"aspect_ratio": data.get('ratio')
}
image_data = data.get('image_data', [])
diff --git a/services/generation_service.py b/services/generation_service.py
index 1276b69..d194ede 100644
--- a/services/generation_service.py
+++ b/services/generation_service.py
@@ -28,6 +28,7 @@ def validate_generation_request(user, data):
is_premium = data.get('is_premium', False)
input_key = data.get('apiKey')
model_value = data.get('model')
+ n = int(data.get('n', 1)) # 获取请求中的数量
target_api = Config.AI_API
api_key = None
@@ -50,13 +51,17 @@ def validate_generation_request(user, data):
else:
return None, None, 0, False, "可用积分已耗尽,请充值或切换至自定义 Key 模式"
- cost = get_model_cost(model_value, is_video=False)
+ # 计算单价
+ base_cost = get_model_cost(model_value, is_video=False)
if use_trial and is_premium:
- cost *= 2
+ base_cost *= 2
+
+ # 总消耗 = 单价 * 数量
+ cost = base_cost * n
if use_trial:
if user.points < cost:
- return None, None, cost, True, "可用积分不足"
+ return None, None, cost, True, f"可用积分不足(本次需要 {cost} 积分)"
return api_key, target_api, cost, use_trial, None
diff --git a/static/js/main.js b/static/js/main.js
index 2093565..0a28ece 100644
--- a/static/js/main.js
+++ b/static/js/main.js
@@ -195,6 +195,8 @@ async function init() {
if (modeTrialBtn) modeTrialBtn.onclick = () => switchMode('trial');
if (modeKeyBtn) modeKeyBtn.onclick = () => switchMode('key');
if (isPremiumCheckbox) isPremiumCheckbox.onchange = () => updateCostPreview();
+ const numSelect = document.getElementById('numSelect');
+ if (numSelect) numSelect.onchange = () => updateCostPreview();
// 历史记录控制
const historyDrawer = document.getElementById('historyDrawer');
@@ -305,13 +307,16 @@ function fillSelect(id, list) {
function updateCostPreview() {
const modelSelect = document.getElementById('modelSelect');
const costPreview = document.getElementById('costPreview');
+ const numSelect = document.getElementById('numSelect');
const isPremium = document.getElementById('isPremium')?.checked || false;
const selectedOption = modelSelect.options[modelSelect.selectedIndex];
if (currentMode === 'trial' && selectedOption) {
- let cost = parseInt(selectedOption.getAttribute('data-cost') || 0);
- if (isPremium) cost *= 2; // 优质模式 2 倍积分
- costPreview.innerText = `本次生成将消耗 ${cost} 积分`;
+ let baseCost = parseInt(selectedOption.getAttribute('data-cost') || 0);
+ let num = parseInt(numSelect?.value || 1);
+ let totalCost = baseCost * num;
+ if (isPremium) totalCost *= 2; // 优质模式 2 倍积分
+ costPreview.innerText = `本次生成将消耗 ${totalCost} 积分`;
costPreview.classList.remove('hidden');
} else {
costPreview.classList.add('hidden');
@@ -498,108 +503,103 @@ document.getElementById('submitBtn').onclick = async () => {
image_data = await Promise.all(uploadedFiles.map(f => readFileAsBase64(f)));
}
- // 2. 并行启动多个生成任务
- btnText.innerText = `AI 构思中 (0/${num})...`;
- let finishedCount = 0;
+ // 2. 发起单次生成请求 (包含数量 n)
+ btnText.innerText = `AI 构思中...`;
- const startTask = async (index) => {
+ // 预先创建 Slot (对应数量)
+ const slots = [];
+ for (let i = 0; i < num; i++) {
const slot = document.createElement('div');
slot.className = 'image-frame relative bg-white/50 animate-pulse min-h-[200px] flex items-center justify-center rounded-[2.5rem] border border-slate-100 shadow-sm';
slot.innerHTML = `
正在排队中...
`;
grid.appendChild(slot);
+ slots.push(slot);
+ }
- try {
- // 1. 发起生成请求,获取任务 ID
- const r = await fetch('/api/generate', {
- method: 'POST',
- headers: { 'Content-Type': 'application/json' },
- body: JSON.stringify({
- mode: currentMode,
- is_premium: document.getElementById('isPremium')?.checked || false,
- apiKey: currentMode === 'key' ? apiKey : '',
- prompt: document.getElementById('manualPrompt').value,
- model: document.getElementById('modelSelect').value,
- ratio: document.getElementById('ratioSelect').value,
- size: document.getElementById('sizeSelect').value,
- image_data // 发送 Base64 数组
- })
- });
- const res = await r.json();
- if (res.error) throw new Error(res.error);
+ // 3. 提交任务
+ const r = await fetch('/api/generate', {
+ method: 'POST',
+ headers: { 'Content-Type': 'application/json' },
+ body: JSON.stringify({
+ mode: currentMode,
+ n: num, // 发送数量
+ is_premium: document.getElementById('isPremium')?.checked || false,
+ apiKey: currentMode === 'key' ? apiKey : '',
+ prompt: document.getElementById('manualPrompt').value,
+ model: document.getElementById('modelSelect').value,
+ ratio: document.getElementById('ratioSelect').value,
+ size: document.getElementById('sizeSelect').value,
+ image_data // 发送 Base64 数组
+ })
+ });
+ const res = await r.json();
+ if (res.error) throw new Error(res.error);
- // 如果直接返回了 data (比如聊天模型),直接显示
- if (res.data) {
- displayResult(slot, res.data[0]);
- return;
- }
+ // 如果直接返回了 data (比如聊天模型),直接显示在第一个 slot
+ if (res.data) {
+ displayResult(slots[0], res.data[0]);
+ // 移除其他多余的 slots
+ for (let i = 1; i < slots.length; i++) slots[i].remove();
+ return;
+ }
- // 2. 轮询任务状态
- const taskId = res.task_id;
- let pollCount = 0;
- const maxPolls = 500; // 最多轮询约 16 分钟 (2s * 500 = 1000s)
+ // 4. 轮询任务状态
+ const taskId = res.task_id;
+ let pollCount = 0;
+ const maxPolls = 500;
- while (pollCount < maxPolls) {
- await new Promise(resolve => setTimeout(resolve, 2000));
- pollCount++;
+ while (pollCount < maxPolls) {
+ await new Promise(resolve => setTimeout(resolve, 2000));
+ pollCount++;
- const statusR = await fetch(`/api/task_status/${taskId}`);
- const statusRes = await statusR.json();
+ const statusR = await fetch(`/api/task_status/${taskId}`);
+ const statusRes = await statusR.json();
- if (statusRes.status === 'complete') {
- const imgUrl = statusRes.urls[0];
- currentGeneratedUrls.push(imgUrl);
- displayResult(slot, { url: imgUrl });
- finishedCount++;
- btnText.innerText = `AI 构思中 (${finishedCount}/${num})...`;
- if (currentMode === 'trial') checkAuth();
- return; // 任务正常结束
- } else if (statusRes.status === 'error') {
- throw new Error(statusRes.message || "生成失败");
- } else {
- // 继续轮询状态显示
- slot.innerHTML = `AI 正在努力创作中 (${pollCount * 2}s)...
`;
+ if (statusRes.status === 'complete') {
+ const urls = statusRes.urls;
+ currentGeneratedUrls = urls;
+
+ // 填充所有 slots
+ urls.forEach((url, idx) => {
+ if (slots[idx]) {
+ displayResult(slots[idx], { url });
}
- }
- throw new Error("生成超时,请稍后在历史记录中查看");
+ });
- } catch (e) {
- slot.className = 'image-frame relative bg-red-50/50 flex items-center justify-center rounded-[2.5rem] border border-red-100 p-6';
- if (e.message.includes('401') || e.message.includes('请先登录')) {
- slot.innerHTML = `登录已过期,请重新登录
`;
- } else {
- slot.innerHTML = `生成异常: ${e.message}
`;
+ // 如果返回的数量少于 slots 数量,移除多余的
+ if (urls.length < slots.length) {
+ for (let i = urls.length; i < slots.length; i++) slots[i].remove();
}
- }
- };
- // 提取结果展示逻辑
- const displayResult = (slot, data) => {
- if (data.type === 'text') {
- slot.className = 'image-frame relative bg-white border border-slate-100 p-8 rounded-[2.5rem] shadow-xl overflow-y-auto max-h-[600px]';
- slot.innerHTML = `${data.content.replace(/\n/g, '
')}
`;
+ btnText.innerText = `生成完成`;
+ if (currentMode === 'trial') checkAuth();
+ return;
+ } else if (statusRes.status === 'error') {
+ throw new Error(statusRes.message || "生成失败");
} else {
- const imgUrl = data.url;
- slot.className = 'image-frame group relative animate-in zoom-in-95 duration-700 flex flex-col items-center justify-center overflow-hidden bg-white shadow-2xl transition-all hover:shadow-indigo-100/50';
- slot.innerHTML = `
-
-

-
-
-
-
- `;
+ // 更新所有 slot 的轮询状态显示
+ slots.forEach(slot => {
+ slot.innerHTML = `AI 正在努力创作中 (${pollCount * 2}s)...
`;
+ });
}
- lucide.createIcons();
- };
-
- const tasks = Array.from({ length: num }, (_, i) => startTask(i));
- await Promise.all(tasks);
+ }
+ throw new Error("生成超时,请稍后在历史记录中查看");
} catch (e) {
- showToast('创作引擎中断: ' + e.message, 'error');
- document.getElementById('placeholder').classList.remove('hidden');
+ // 在第一个 slot 显示错误
+ if (grid.firstChild) {
+ const firstSlot = grid.firstChild;
+ firstSlot.className = 'image-frame relative bg-red-50/50 flex items-center justify-center rounded-[2.5rem] border border-red-100 p-6';
+ if (e.message.includes('401') || e.message.includes('请先登录')) {
+ firstSlot.innerHTML = `登录已过期,请重新登录
`;
+ } else {
+ firstSlot.innerHTML = `生成异常: ${e.message}
`;
+ }
+ // 移除其他 slots
+ while (grid.children.length > 1) {
+ grid.lastChild.remove();
+ }
+ }
} finally {
btn.disabled = false;
btnText.innerText = "立即生成作品";
@@ -607,6 +607,28 @@ document.getElementById('submitBtn').onclick = async () => {
}
};
+// 提取结果展示逻辑
+const displayResult = (slot, data) => {
+ if (data.type === 'text') {
+ slot.className = 'image-frame relative bg-white border border-slate-100 p-8 rounded-[2.5rem] shadow-xl overflow-y-auto max-h-[600px]';
+ slot.innerHTML = `${data.content.replace(/\n/g, '
')}
`;
+ } else {
+ const imgUrl = data.url;
+ slot.className = 'image-frame group relative animate-in zoom-in-95 duration-700 flex flex-col items-center justify-center overflow-hidden bg-white shadow-2xl transition-all hover:shadow-indigo-100/50';
+ slot.innerHTML = `
+
+

+
+
+
+
+ `;
+ }
+ lucide.createIcons();
+};
+
init();
// 修改密码弹窗控制
From 1202291e4b9f4a4c28ed521ca15dd980c15e318e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=85=AC=E5=8F=B8git?= <240241002@qq.com>
Date: Tue, 20 Jan 2026 09:53:06 +0800
Subject: [PATCH 2/3] feat: Implement AI image, video, and chat generation
features with integrated point and API key management.
---
blueprints/api.py | 5 +-
services/generation_service.py | 26 +++---
static/js/main.js | 151 +++++++++++++++------------------
3 files changed, 81 insertions(+), 101 deletions(-)
diff --git a/blueprints/api.py b/blueprints/api.py
index 72fee15..1ea786e 100644
--- a/blueprints/api.py
+++ b/blueprints/api.py
@@ -68,7 +68,7 @@ def generate():
# 2. 扣除积分 (如果是试用模式)
if use_trial:
- deduct_points(user, cost)
+ deduct_points(user_id, cost)
model_value = data.get('model')
prompt = data.get('prompt')
@@ -84,7 +84,6 @@ def generate():
"prompt": prompt,
"model": model_value,
"response_format": "url",
- "n": int(data.get('n', 1)), # 添加数量
"aspect_ratio": data.get('ratio')
}
image_data = data.get('image_data', [])
@@ -120,7 +119,7 @@ def video_generate():
return jsonify({"error": error}), 400
# 2. 扣除积分
- deduct_points(user, cost)
+ deduct_points(user_id, cost)
# 3. 构造 Payload
payload = {
diff --git a/services/generation_service.py b/services/generation_service.py
index d194ede..821a9c9 100644
--- a/services/generation_service.py
+++ b/services/generation_service.py
@@ -28,7 +28,6 @@ def validate_generation_request(user, data):
is_premium = data.get('is_premium', False)
input_key = data.get('apiKey')
model_value = data.get('model')
- n = int(data.get('n', 1)) # 获取请求中的数量
target_api = Config.AI_API
api_key = None
@@ -52,12 +51,9 @@ def validate_generation_request(user, data):
return None, None, 0, False, "可用积分已耗尽,请充值或切换至自定义 Key 模式"
# 计算单价
- base_cost = get_model_cost(model_value, is_video=False)
+ cost = get_model_cost(model_value, is_video=False)
if use_trial and is_premium:
- base_cost *= 2
-
- # 总消耗 = 单价 * 数量
- cost = base_cost * n
+ cost *= 2
if use_trial:
if user.points < cost:
@@ -65,21 +61,23 @@ def validate_generation_request(user, data):
return api_key, target_api, cost, use_trial, None
-def deduct_points(user, cost):
- """扣除积分"""
- user.points -= cost
- user.has_used_points = True
- db.session.commit()
+def deduct_points(user_id, cost):
+ """原子扣除积分"""
+ user = db.session.query(User).filter_by(id=user_id).populate_existing().with_for_update().first()
+ if user:
+ user.points -= cost
+ user.has_used_points = True
+ db.session.commit()
def refund_points(user_id, cost):
- """退还积分"""
+ """原子退还积分"""
try:
- user = db.session.get(User, user_id)
+ user = db.session.query(User).filter_by(id=user_id).populate_existing().with_for_update().first()
if user:
user.points += cost
db.session.commit()
except:
- pass
+ db.session.rollback()
def handle_chat_generation_sync(user_id, api_key, model_value, prompt, use_trial, cost):
"""同步处理对话类模型"""
diff --git a/static/js/main.js b/static/js/main.js
index 0a28ece..fb6b1fb 100644
--- a/static/js/main.js
+++ b/static/js/main.js
@@ -503,103 +503,86 @@ document.getElementById('submitBtn').onclick = async () => {
image_data = await Promise.all(uploadedFiles.map(f => readFileAsBase64(f)));
}
- // 2. 发起单次生成请求 (包含数量 n)
- btnText.innerText = `AI 构思中...`;
+ // 2. 并行启动多个生成任务
+ btnText.innerText = `AI 构思中 (0/${num})...`;
+ let finishedCount = 0;
- // 预先创建 Slot (对应数量)
- const slots = [];
- for (let i = 0; i < num; i++) {
+ const startTask = async (index) => {
const slot = document.createElement('div');
slot.className = 'image-frame relative bg-white/50 animate-pulse min-h-[200px] flex items-center justify-center rounded-[2.5rem] border border-slate-100 shadow-sm';
slot.innerHTML = `正在排队中...
`;
grid.appendChild(slot);
- slots.push(slot);
- }
- // 3. 提交任务
- const r = await fetch('/api/generate', {
- method: 'POST',
- headers: { 'Content-Type': 'application/json' },
- body: JSON.stringify({
- mode: currentMode,
- n: num, // 发送数量
- is_premium: document.getElementById('isPremium')?.checked || false,
- apiKey: currentMode === 'key' ? apiKey : '',
- prompt: document.getElementById('manualPrompt').value,
- model: document.getElementById('modelSelect').value,
- ratio: document.getElementById('ratioSelect').value,
- size: document.getElementById('sizeSelect').value,
- image_data // 发送 Base64 数组
- })
- });
- const res = await r.json();
- if (res.error) throw new Error(res.error);
-
- // 如果直接返回了 data (比如聊天模型),直接显示在第一个 slot
- if (res.data) {
- displayResult(slots[0], res.data[0]);
- // 移除其他多余的 slots
- for (let i = 1; i < slots.length; i++) slots[i].remove();
- return;
- }
-
- // 4. 轮询任务状态
- const taskId = res.task_id;
- let pollCount = 0;
- const maxPolls = 500;
-
- while (pollCount < maxPolls) {
- await new Promise(resolve => setTimeout(resolve, 2000));
- pollCount++;
-
- const statusR = await fetch(`/api/task_status/${taskId}`);
- const statusRes = await statusR.json();
-
- if (statusRes.status === 'complete') {
- const urls = statusRes.urls;
- currentGeneratedUrls = urls;
-
- // 填充所有 slots
- urls.forEach((url, idx) => {
- if (slots[idx]) {
- displayResult(slots[idx], { url });
- }
+ try {
+ // 1. 发起生成请求,获取任务 ID
+ const r = await fetch('/api/generate', {
+ method: 'POST',
+ headers: { 'Content-Type': 'application/json' },
+ body: JSON.stringify({
+ mode: currentMode,
+ is_premium: document.getElementById('isPremium')?.checked || false,
+ apiKey: currentMode === 'key' ? apiKey : '',
+ prompt: document.getElementById('manualPrompt').value,
+ model: document.getElementById('modelSelect').value,
+ ratio: document.getElementById('ratioSelect').value,
+ size: document.getElementById('sizeSelect').value,
+ image_data // 发送 Base64 数组
+ })
});
+ const res = await r.json();
+ if (res.error) throw new Error(res.error);
- // 如果返回的数量少于 slots 数量,移除多余的
- if (urls.length < slots.length) {
- for (let i = urls.length; i < slots.length; i++) slots[i].remove();
+ // 如果直接返回了 data (比如聊天模型),直接显示
+ if (res.data) {
+ displayResult(slot, res.data[0]);
+ return;
}
- btnText.innerText = `生成完成`;
- if (currentMode === 'trial') checkAuth();
- return;
- } else if (statusRes.status === 'error') {
- throw new Error(statusRes.message || "生成失败");
- } else {
- // 更新所有 slot 的轮询状态显示
- slots.forEach(slot => {
- slot.innerHTML = `AI 正在努力创作中 (${pollCount * 2}s)...
`;
- });
+ // 2. 轮询任务状态
+ const taskId = res.task_id;
+ let pollCount = 0;
+ const maxPolls = 500; // 最多轮询约 16 分钟 (2s * 500 = 1000s)
+
+ while (pollCount < maxPolls) {
+ await new Promise(resolve => setTimeout(resolve, 2000));
+ pollCount++;
+
+ const statusR = await fetch(`/api/task_status/${taskId}`);
+ const statusRes = await statusR.json();
+
+ if (statusRes.status === 'complete') {
+ const imgUrl = statusRes.urls[0];
+ currentGeneratedUrls.push(imgUrl);
+ displayResult(slot, { url: imgUrl });
+ finishedCount++;
+ btnText.innerText = `AI 构思中 (${finishedCount}/${num})...`;
+ if (currentMode === 'trial') checkAuth();
+ return; // 任务正常结束
+ } else if (statusRes.status === 'error') {
+ throw new Error(statusRes.message || "生成失败");
+ } else {
+ // 继续轮询状态显示
+ slot.innerHTML = `AI 正在努力创作中 (${pollCount * 2}s)...
`;
+ }
+ }
+ throw new Error("生成超时,请稍后在历史记录中查看");
+
+ } catch (e) {
+ slot.className = 'image-frame relative bg-red-50/50 flex items-center justify-center rounded-[2.5rem] border border-red-100 p-6';
+ if (e.message.includes('401') || e.message.includes('请先登录')) {
+ slot.innerHTML = `登录已过期,请重新登录
`;
+ } else {
+ slot.innerHTML = `生成异常: ${e.message}
`;
+ }
}
- }
- throw new Error("生成超时,请稍后在历史记录中查看");
+ };
+
+ const tasks = Array.from({ length: num }, (_, i) => startTask(i));
+ await Promise.all(tasks);
} catch (e) {
- // 在第一个 slot 显示错误
- if (grid.firstChild) {
- const firstSlot = grid.firstChild;
- firstSlot.className = 'image-frame relative bg-red-50/50 flex items-center justify-center rounded-[2.5rem] border border-red-100 p-6';
- if (e.message.includes('401') || e.message.includes('请先登录')) {
- firstSlot.innerHTML = `登录已过期,请重新登录
`;
- } else {
- firstSlot.innerHTML = `生成异常: ${e.message}
`;
- }
- // 移除其他 slots
- while (grid.children.length > 1) {
- grid.lastChild.remove();
- }
- }
+ showToast('创作引擎中断: ' + e.message, 'error');
+ document.getElementById('placeholder').classList.remove('hidden');
} finally {
btn.disabled = false;
btnText.innerText = "立即生成作品";
From 2453bb05eae35ecbab8dc45daf9d4c4d0cf6f63a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=85=AC=E5=8F=B8git?= <240241002@qq.com>
Date: Tue, 20 Jan 2026 16:01:58 +0800
Subject: [PATCH 3/3] feat: Implement core API endpoints for AI content
generation, user management, and Alipay payment processing.
---
blueprints/api.py | 2 +-
blueprints/payment.py | 5 ++---
services/generation_service.py | 6 +++---
services/task_service.py | 31 ++++++++++++++-----------------
4 files changed, 20 insertions(+), 24 deletions(-)
diff --git a/blueprints/api.py b/blueprints/api.py
index 1ea786e..cf72d0a 100644
--- a/blueprints/api.py
+++ b/blueprints/api.py
@@ -95,7 +95,7 @@ def generate():
# 5. 启动异步生图任务
app = current_app._get_current_object()
- task_id = start_async_image_task(app, user_id, payload, api_key, target_api, cost, data.get('mode'), model_value)
+ task_id = start_async_image_task(app, user_id, payload, api_key, target_api, cost, data.get('mode'), model_value, use_trial)
return jsonify({
"task_id": task_id,
diff --git a/blueprints/payment.py b/blueprints/payment.py
index ae426a2..2bc2910 100644
--- a/blueprints/payment.py
+++ b/blueprints/payment.py
@@ -4,7 +4,7 @@ from models import Order, User, to_bj_time, get_bj_now
from services.alipay_service import AlipayService
from services.logger import system_logger
import uuid
-from datetime import timedelta
+from datetime import datetime, timedelta
payment_bp = Blueprint('payment', __name__, url_prefix='/payment')
@@ -102,8 +102,7 @@ def payment_history():
)
).order_by(Order.created_at.desc()).all()
- import datetime as dt_module
- return render_template('recharge_history.html', orders=orders, modules={'datetime': dt_module})
+ return render_template('recharge_history.html', orders=orders, modules={'datetime': datetime})
@payment_bp.route('/api/history', methods=['GET'])
def api_payment_history():
diff --git a/services/generation_service.py b/services/generation_service.py
index 821a9c9..9a5380f 100644
--- a/services/generation_service.py
+++ b/services/generation_service.py
@@ -117,7 +117,7 @@ def handle_chat_generation_sync(user_id, api_key, model_value, prompt, use_trial
refund_points(user_id, cost)
return {"error": str(e)}, 500
-def start_async_image_task(app, user_id, payload, api_key, target_api, cost, mode, model_value):
+def start_async_image_task(app, user_id, payload, api_key, target_api, cost, mode, model_value, use_trial=False):
"""启动异步生图任务"""
task_id = str(uuid.uuid4())
@@ -126,7 +126,7 @@ def start_async_image_task(app, user_id, payload, api_key, target_api, cost, mod
threading.Thread(
target=process_image_generation,
- args=(app, user_id, task_id, payload, api_key, target_api, cost)
+ args=(app, user_id, task_id, payload, api_key, target_api, cost, use_trial)
).start()
return task_id
@@ -153,7 +153,7 @@ def start_async_video_task(app, user_id, payload, cost, model_value):
threading.Thread(
target=process_video_generation,
- args=(app, user_id, task_id, payload, api_key, cost)
+ args=(app, user_id, task_id, payload, api_key, cost, True) # 视频目前默认为积分模式
).start()
return task_id
diff --git a/services/task_service.py b/services/task_service.py
index e873356..a216a66 100644
--- a/services/task_service.py
+++ b/services/task_service.py
@@ -90,7 +90,7 @@ def sync_images_background(app, record_id, raw_urls):
except Exception as e:
print(f"❌ 更新记录失败: {e}")
-def process_image_generation(app, user_id, task_id, payload, api_key, target_api, cost):
+def process_image_generation(app, user_id, task_id, payload, api_key, target_api, cost, use_trial=False):
"""异步执行图片生成并存入 Redis"""
with app.app_context():
try:
@@ -99,10 +99,9 @@ def process_image_generation(app, user_id, task_id, payload, api_key, target_api
resp = requests.post(target_api, json=payload, headers=headers, timeout=1000)
if resp.status_code != 200:
- user = db.session.get(User, user_id)
- if user and "sk-" in api_key:
- user.points += cost
- db.session.commit()
+ if use_trial:
+ from services.generation_service import refund_points
+ refund_points(user_id, cost)
# 记录详细的失败上下文
system_logger.error(f"生图任务失败: {resp.text}", user_id=user_id, task_id=task_id, prompt=payload.get('prompt'), model=payload.get('model'))
@@ -135,10 +134,9 @@ def process_image_generation(app, user_id, task_id, payload, api_key, target_api
except Exception as e:
# 异常处理:退还积分
- user = db.session.get(User, user_id)
- if user and "sk-" in api_key:
- user.points += cost
- db.session.commit()
+ if use_trial:
+ from services.generation_service import refund_points
+ refund_points(user_id, cost)
system_logger.error(f"生图任务异常: {str(e)}", user_id=user_id, task_id=task_id, prompt=payload.get('prompt'), model=payload.get('model'))
redis_client.setex(f"task:{task_id}", 3600, json.dumps({"status": "error", "message": str(e)}))
@@ -203,7 +201,7 @@ def sync_video_background(app, record_id, raw_url, internal_task_id=None):
except Exception as dbe:
system_logger.error(f"更新视频记录失败: {str(dbe)}")
-def process_video_generation(app, user_id, internal_task_id, payload, api_key, cost):
+def process_video_generation(app, user_id, internal_task_id, payload, api_key, cost, use_trial=True):
"""异步提交并查询视频任务状态"""
with app.app_context():
try:
@@ -273,13 +271,12 @@ def process_video_generation(app, user_id, internal_task_id, payload, api_key, c
except Exception as e:
system_logger.error(f"视频生成执行异常: {str(e)}", user_id=user_id, task_id=internal_task_id, prompt=payload.get('prompt'))
# 尝试退费
- try:
- user = db.session.get(User, user_id)
- if user:
- user.points += cost
- db.session.commit()
- except Exception as re:
- system_logger.error(f"退费失败: {str(re)}")
+ if use_trial:
+ try:
+ from services.generation_service import refund_points
+ refund_points(user_id, cost)
+ except Exception as re:
+ system_logger.error(f"退费失败: {str(re)}")
# 确保 Redis 状态一定被更新,防止前端死循环
redis_client.setex(f"task:{internal_task_id}", 3600, json.dumps({"status": "error", "message": str(e)}))