diff --git a/README.md b/README.md index 52f32f1..71dc125 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,14 @@ ### 1. 环境准备 确保已安装 Python 3.8+ 和 PostgreSQL / Redis。 +**创建虚拟环境** +```bash +# 如果 python 命令不可用,请尝试使用 py +python -m venv .venv +# 或者 +py -m venv .venv +``` + **激活虚拟环境** (推荐) ```bash # Windows (PowerShell) diff --git a/blueprints/api.py b/blueprints/api.py index b556c52..cf72d0a 100644 --- a/blueprints/api.py +++ b/blueprints/api.py @@ -68,7 +68,7 @@ def generate(): # 2. 扣除积分 (如果是试用模式) if use_trial: - deduct_points(user, cost) + deduct_points(user_id, cost) model_value = data.get('model') prompt = data.get('prompt') @@ -95,7 +95,7 @@ def generate(): # 5. 启动异步生图任务 app = current_app._get_current_object() - task_id = start_async_image_task(app, user_id, payload, api_key, target_api, cost, data.get('mode'), model_value) + task_id = start_async_image_task(app, user_id, payload, api_key, target_api, cost, data.get('mode'), model_value, use_trial) return jsonify({ "task_id": task_id, @@ -119,7 +119,7 @@ def video_generate(): return jsonify({"error": error}), 400 # 2. 扣除积分 - deduct_points(user, cost) + deduct_points(user_id, cost) # 3. 构造 Payload payload = { diff --git a/blueprints/payment.py b/blueprints/payment.py index ae426a2..2bc2910 100644 --- a/blueprints/payment.py +++ b/blueprints/payment.py @@ -4,7 +4,7 @@ from models import Order, User, to_bj_time, get_bj_now from services.alipay_service import AlipayService from services.logger import system_logger import uuid -from datetime import timedelta +from datetime import datetime, timedelta payment_bp = Blueprint('payment', __name__, url_prefix='/payment') @@ -102,8 +102,7 @@ def payment_history(): ) ).order_by(Order.created_at.desc()).all() - import datetime as dt_module - return render_template('recharge_history.html', orders=orders, modules={'datetime': dt_module}) + return render_template('recharge_history.html', orders=orders, modules={'datetime': datetime}) @payment_bp.route('/api/history', methods=['GET']) def api_payment_history(): diff --git a/services/generation_service.py b/services/generation_service.py index 1276b69..9a5380f 100644 --- a/services/generation_service.py +++ b/services/generation_service.py @@ -50,31 +50,34 @@ def validate_generation_request(user, data): else: return None, None, 0, False, "可用积分已耗尽,请充值或切换至自定义 Key 模式" + # 计算单价 cost = get_model_cost(model_value, is_video=False) if use_trial and is_premium: cost *= 2 if use_trial: if user.points < cost: - return None, None, cost, True, "可用积分不足" + return None, None, cost, True, f"可用积分不足(本次需要 {cost} 积分)" return api_key, target_api, cost, use_trial, None -def deduct_points(user, cost): - """扣除积分""" - user.points -= cost - user.has_used_points = True - db.session.commit() +def deduct_points(user_id, cost): + """原子扣除积分""" + user = db.session.query(User).filter_by(id=user_id).populate_existing().with_for_update().first() + if user: + user.points -= cost + user.has_used_points = True + db.session.commit() def refund_points(user_id, cost): - """退还积分""" + """原子退还积分""" try: - user = db.session.get(User, user_id) + user = db.session.query(User).filter_by(id=user_id).populate_existing().with_for_update().first() if user: user.points += cost db.session.commit() except: - pass + db.session.rollback() def handle_chat_generation_sync(user_id, api_key, model_value, prompt, use_trial, cost): """同步处理对话类模型""" @@ -114,7 +117,7 @@ def handle_chat_generation_sync(user_id, api_key, model_value, prompt, use_trial refund_points(user_id, cost) return {"error": str(e)}, 500 -def start_async_image_task(app, user_id, payload, api_key, target_api, cost, mode, model_value): +def start_async_image_task(app, user_id, payload, api_key, target_api, cost, mode, model_value, use_trial=False): """启动异步生图任务""" task_id = str(uuid.uuid4()) @@ -123,7 +126,7 @@ def start_async_image_task(app, user_id, payload, api_key, target_api, cost, mod threading.Thread( target=process_image_generation, - args=(app, user_id, task_id, payload, api_key, target_api, cost) + args=(app, user_id, task_id, payload, api_key, target_api, cost, use_trial) ).start() return task_id @@ -150,7 +153,7 @@ def start_async_video_task(app, user_id, payload, cost, model_value): threading.Thread( target=process_video_generation, - args=(app, user_id, task_id, payload, api_key, cost) + args=(app, user_id, task_id, payload, api_key, cost, True) # 视频目前默认为积分模式 ).start() return task_id diff --git a/services/task_service.py b/services/task_service.py index e873356..a216a66 100644 --- a/services/task_service.py +++ b/services/task_service.py @@ -90,7 +90,7 @@ def sync_images_background(app, record_id, raw_urls): except Exception as e: print(f"❌ 更新记录失败: {e}") -def process_image_generation(app, user_id, task_id, payload, api_key, target_api, cost): +def process_image_generation(app, user_id, task_id, payload, api_key, target_api, cost, use_trial=False): """异步执行图片生成并存入 Redis""" with app.app_context(): try: @@ -99,10 +99,9 @@ def process_image_generation(app, user_id, task_id, payload, api_key, target_api resp = requests.post(target_api, json=payload, headers=headers, timeout=1000) if resp.status_code != 200: - user = db.session.get(User, user_id) - if user and "sk-" in api_key: - user.points += cost - db.session.commit() + if use_trial: + from services.generation_service import refund_points + refund_points(user_id, cost) # 记录详细的失败上下文 system_logger.error(f"生图任务失败: {resp.text}", user_id=user_id, task_id=task_id, prompt=payload.get('prompt'), model=payload.get('model')) @@ -135,10 +134,9 @@ def process_image_generation(app, user_id, task_id, payload, api_key, target_api except Exception as e: # 异常处理:退还积分 - user = db.session.get(User, user_id) - if user and "sk-" in api_key: - user.points += cost - db.session.commit() + if use_trial: + from services.generation_service import refund_points + refund_points(user_id, cost) system_logger.error(f"生图任务异常: {str(e)}", user_id=user_id, task_id=task_id, prompt=payload.get('prompt'), model=payload.get('model')) redis_client.setex(f"task:{task_id}", 3600, json.dumps({"status": "error", "message": str(e)})) @@ -203,7 +201,7 @@ def sync_video_background(app, record_id, raw_url, internal_task_id=None): except Exception as dbe: system_logger.error(f"更新视频记录失败: {str(dbe)}") -def process_video_generation(app, user_id, internal_task_id, payload, api_key, cost): +def process_video_generation(app, user_id, internal_task_id, payload, api_key, cost, use_trial=True): """异步提交并查询视频任务状态""" with app.app_context(): try: @@ -273,13 +271,12 @@ def process_video_generation(app, user_id, internal_task_id, payload, api_key, c except Exception as e: system_logger.error(f"视频生成执行异常: {str(e)}", user_id=user_id, task_id=internal_task_id, prompt=payload.get('prompt')) # 尝试退费 - try: - user = db.session.get(User, user_id) - if user: - user.points += cost - db.session.commit() - except Exception as re: - system_logger.error(f"退费失败: {str(re)}") + if use_trial: + try: + from services.generation_service import refund_points + refund_points(user_id, cost) + except Exception as re: + system_logger.error(f"退费失败: {str(re)}") # 确保 Redis 状态一定被更新,防止前端死循环 redis_client.setex(f"task:{internal_task_id}", 3600, json.dumps({"status": "error", "message": str(e)})) diff --git a/static/js/main.js b/static/js/main.js index 2093565..fb6b1fb 100644 --- a/static/js/main.js +++ b/static/js/main.js @@ -195,6 +195,8 @@ async function init() { if (modeTrialBtn) modeTrialBtn.onclick = () => switchMode('trial'); if (modeKeyBtn) modeKeyBtn.onclick = () => switchMode('key'); if (isPremiumCheckbox) isPremiumCheckbox.onchange = () => updateCostPreview(); + const numSelect = document.getElementById('numSelect'); + if (numSelect) numSelect.onchange = () => updateCostPreview(); // 历史记录控制 const historyDrawer = document.getElementById('historyDrawer'); @@ -305,13 +307,16 @@ function fillSelect(id, list) { function updateCostPreview() { const modelSelect = document.getElementById('modelSelect'); const costPreview = document.getElementById('costPreview'); + const numSelect = document.getElementById('numSelect'); const isPremium = document.getElementById('isPremium')?.checked || false; const selectedOption = modelSelect.options[modelSelect.selectedIndex]; if (currentMode === 'trial' && selectedOption) { - let cost = parseInt(selectedOption.getAttribute('data-cost') || 0); - if (isPremium) cost *= 2; // 优质模式 2 倍积分 - costPreview.innerText = `本次生成将消耗 ${cost} 积分`; + let baseCost = parseInt(selectedOption.getAttribute('data-cost') || 0); + let num = parseInt(numSelect?.value || 1); + let totalCost = baseCost * num; + if (isPremium) totalCost *= 2; // 优质模式 2 倍积分 + costPreview.innerText = `本次生成将消耗 ${totalCost} 积分`; costPreview.classList.remove('hidden'); } else { costPreview.classList.add('hidden'); @@ -572,28 +577,6 @@ document.getElementById('submitBtn').onclick = async () => { } }; - // 提取结果展示逻辑 - const displayResult = (slot, data) => { - if (data.type === 'text') { - slot.className = 'image-frame relative bg-white border border-slate-100 p-8 rounded-[2.5rem] shadow-xl overflow-y-auto max-h-[600px]'; - slot.innerHTML = `
${data.content.replace(/\n/g, '
')}
`; - } else { - const imgUrl = data.url; - slot.className = 'image-frame group relative animate-in zoom-in-95 duration-700 flex flex-col items-center justify-center overflow-hidden bg-white shadow-2xl transition-all hover:shadow-indigo-100/50'; - slot.innerHTML = ` -
- -
-
- -
- `; - } - lucide.createIcons(); - }; - const tasks = Array.from({ length: num }, (_, i) => startTask(i)); await Promise.all(tasks); @@ -607,6 +590,28 @@ document.getElementById('submitBtn').onclick = async () => { } }; +// 提取结果展示逻辑 +const displayResult = (slot, data) => { + if (data.type === 'text') { + slot.className = 'image-frame relative bg-white border border-slate-100 p-8 rounded-[2.5rem] shadow-xl overflow-y-auto max-h-[600px]'; + slot.innerHTML = `
${data.content.replace(/\n/g, '
')}
`; + } else { + const imgUrl = data.url; + slot.className = 'image-frame group relative animate-in zoom-in-95 duration-700 flex flex-col items-center justify-center overflow-hidden bg-white shadow-2xl transition-all hover:shadow-indigo-100/50'; + slot.innerHTML = ` +
+ +
+
+ +
+ `; + } + lucide.createIcons(); +}; + init(); // 修改密码弹窗控制