160 lines
5.4 KiB
Python
160 lines
5.4 KiB
Python
from config import Config
|
|
from models import SystemDict, GenerationRecord, User, db
|
|
from services.logger import system_logger
|
|
from services.task_service import process_image_generation, process_video_generation
|
|
import requests
|
|
import json
|
|
import uuid
|
|
import threading
|
|
from flask import current_app
|
|
|
|
def get_model_cost(model_value, is_video=False):
    """Return the points cost of running ``model_value``.

    The cost is read from ``SystemDict`` (``dict_type='video_model'`` when
    ``is_video`` is true, otherwise ``'ai_model'``), matched on ``value``.
    When no entry exists a built-in default is used:

    * video models: 15 if the name contains "pro" (case-insensitive) or
      "3.1", otherwise 10;
    * all other models: 1.

    :param model_value: model identifier from the request; may be None.
    :param is_video: True to price a video model.
    :return: the configured cost, or the default described above.
    """
    dict_type = 'video_model' if is_video else 'ai_model'
    model_dict = SystemDict.query.filter_by(dict_type=dict_type, value=model_value).first()

    if model_dict:
        return model_dict.cost

    # Default costs
    if not is_video:
        return 1

    # Normalize before string checks: a missing/null model name would otherwise
    # crash the fallback with AttributeError on None.lower().
    name = "" if model_value is None else str(model_value)
    return 15 if "pro" in name.lower() or "3.1" in name else 10
|
|
|
|
def validate_generation_request(user, data):
    """Validate an image-generation request and decide how it will be billed.

    Returns a 5-tuple ``(api_key, target_api, cost, use_trial, error)``.
    On success ``error`` is None. On failure ``api_key`` and ``target_api``
    are None and ``error`` holds a user-facing message.

    ``data['mode']`` (default ``'trial'``) selects the billing path:

    * ``'key'`` - the caller's own API key is used against ``Config.AI_API``.
      A newly supplied ``apiKey`` is persisted on ``user``. No points are
      checked, but the model's cost is still reported.
    * any other value - a platform key (``PREMIUM_KEY`` when ``is_premium``
      is truthy, else ``TRIAL_KEY``) is used against ``Config.TRIAL_API``. The
      request is paid for with points, and premium requests cost double.
    """
    own_key_api = Config.AI_API
    wants_premium = data.get('is_premium', False)
    supplied_key = data.get('apiKey')
    model_value = data.get('model')

    if data.get('mode', 'trial') == 'key':
        chosen_key = supplied_key or user.api_key
        if not chosen_key:
            return None, None, 0, False, "请先输入您的 API 密钥"

        # Remember a newly entered key so the user does not have to re-type it.
        if supplied_key and supplied_key != user.api_key:
            user.api_key = supplied_key
            db.session.commit()

        cost = get_model_cost(model_value, is_video=False)
        return chosen_key, own_key_api, cost, False, None

    # Points (trial) mode from here on.
    if not user.points > 0:
        return None, None, 0, False, "可用积分已耗尽,请充值或切换至自定义 Key 模式"

    chosen_key = Config.PREMIUM_KEY if wants_premium else Config.TRIAL_KEY
    trial_api = Config.TRIAL_API

    cost = get_model_cost(model_value, is_video=False)
    if wants_premium:
        cost *= 2  # the premium platform key is billed at twice the unit price

    if user.points < cost:
        return None, None, cost, True, f"可用积分不足(本次需要 {cost} 积分)"

    return chosen_key, trial_api, cost, True, None
|
|
|
|
def deduct_points(user_id, cost):
    """Subtract ``cost`` points from user ``user_id`` and mark points as used.

    The user row is fetched with ``with_for_update()``, which locks it for the
    transaction on backends that support it. ``populate_existing()`` refreshes
    any already-loaded instance with the database values before the update.
    If no user matches, nothing is changed.

    On any error the session is rolled back and the exception is re-raised,
    so the caller's session stays usable (e.g. for a later refund).

    NOTE(review): the balance is not re-checked under the lock. The
    validate_* helpers read ``user.points`` without locking, so concurrent
    requests may drive the balance negative. Confirm whether that is acceptable.
    """
    try:
        user = db.session.query(User).filter_by(id=user_id).populate_existing().with_for_update().first()
        if user:
            user.points -= cost
            user.has_used_points = True
            db.session.commit()
    except Exception:
        # Mirror refund_points: never leave the shared session in a failed state.
        db.session.rollback()
        raise
|
|
|
|
def refund_points(user_id, cost):
    """Give ``cost`` points back to user ``user_id`` (best effort).

    The user row is fetched with ``with_for_update()`` and
    ``populate_existing()``, incremented, and committed. If no user matches,
    nothing is changed.

    This function never raises for ordinary errors. Any ``Exception`` is
    rolled back and logged so that callers on an error path are not
    interrupted. ``KeyboardInterrupt``/``SystemExit`` are no longer swallowed.
    """
    import logging

    try:
        user = db.session.query(User).filter_by(id=user_id).populate_existing().with_for_update().first()
        if user:
            user.points += cost
            db.session.commit()
    except Exception:
        db.session.rollback()
        # A failed refund means the user silently loses points; leave a trace.
        logging.getLogger(__name__).exception(
            "refund_points failed: user_id=%s cost=%s", user_id, cost
        )
|
|
|
|
def handle_chat_generation_sync(user_id, api_key, model_value, prompt, use_trial, cost):
    """Run a chat-model request synchronously and return a ``(body, status)`` pair.

    Sends ``prompt`` as a single user message to ``Config.CHAT_API`` using
    ``api_key`` as a Bearer token (120 s timeout).

    On success, returns ``({"data": [{"content": ..., "type": "text"}],
    "message": ...}, 200)``. Unless the prompt is exactly "解读验光单"
    (interpret eyeglass prescription), the reply is also stored as a
    ``GenerationRecord``.

    On failure, returns ``({"error": ...}, status)``:

    * For a non-200 upstream response, the body is the upstream text and the
      status is the upstream status.
    * For any exception, the body is ``str(e)`` and the status is 500.

    In both failure cases ``cost`` points are refunded when ``use_trial`` is
    true.

    NOTE(review): points are presumably deducted by the caller before this is
    invoked; this function only refunds them. Confirm against the route.
    """
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    chat_payload = {
        "model": model_value,
        "messages": [{"role": "user", "content": prompt}],
    }
    try:
        resp = requests.post(Config.CHAT_API, json=chat_payload, headers=headers, timeout=120)
        if resp.status_code != 200:
            if use_trial:
                refund_points(user_id, cost)
            return {"error": resp.text}, resp.status_code

        api_result = resp.json()
        content = api_result['choices'][0]['message']['content']

        # Record chat history (the prescription-reading request is not stored).
        if prompt != "解读验光单":
            new_record = GenerationRecord(
                user_id=user_id,
                prompt=prompt,
                model=model_value,
                cost=cost,
                image_urls=json.dumps([{"type": "text", "content": content}]),
            )
            try:
                db.session.add(new_record)
                db.session.commit()
            except Exception:
                # A failed commit leaves the session needing a rollback.
                # Without it, refund_points' own query below would fail and the
                # refund would be silently dropped.
                db.session.rollback()
                raise

        return {
            "data": [{"content": content, "type": "text"}],
            "message": "生成成功!"
        }, 200
    except Exception as e:
        if use_trial:
            refund_points(user_id, cost)
        return {"error": str(e)}, 500
|
|
|
|
def start_async_image_task(app, user_id, payload, api_key, target_api, cost, mode, model_value, use_trial=False):
    """Log the request and hand it to ``process_image_generation`` on a new thread.

    A fresh UUID4 string is used as the task id and returned immediately,
    without waiting for the worker. The worker is called with
    ``(app, user_id, task_id, payload, api_key, target_api, cost, use_trial)``.
    """
    new_task_id = str(uuid.uuid4())

    # Prescription-reading requests get their own log line.
    if payload.get('prompt') == "解读验光单":
        system_logger.info("用户发起验光单解读", model=model_value, mode=mode)
    else:
        system_logger.info("用户发起生图任务", model=model_value, mode=mode)

    worker_args = (app, user_id, new_task_id, payload, api_key, target_api, cost, use_trial)
    worker = threading.Thread(target=process_image_generation, args=worker_args)
    worker.start()

    return new_task_id
|
|
|
|
def validate_video_request(user, data):
    """Check that ``user`` can pay for a video generation and resolve its model.

    Returns ``(model_value, cost, error)``:

    * If the user has no points left, returns ``(None, 0, message)``.
    * If the user has fewer points than the model costs, returns
      ``(None, cost, message)``.
    * Otherwise returns ``(model_value, cost, None)``.

    The model comes from ``data['model']``. A missing, null, or empty value
    falls back to ``'veo3.1'``.
    """
    if user.points <= 0:
        return None, 0, "可用积分不足,请先充值"

    # `data.get('model', default)` only covers a missing key; an explicit
    # null/empty value would otherwise reach get_model_cost as None/"".
    model_value = data.get('model') or 'veo3.1'
    cost = get_model_cost(model_value, is_video=True)

    if user.points < cost:
        return None, cost, f"积分不足,生成该视频需要 {cost} 积分"

    return model_value, cost, None
|
|
|
|
def start_async_video_task(app, user_id, payload, cost, model_value):
    """Log the request and run ``process_video_generation`` on a new thread.

    Video generation is always billed in points mode. The worker receives
    ``Config.TRIAL_KEY`` as its API key and ``True`` as its final argument.
    A fresh UUID4 string is used as the task id and returned immediately.
    """
    trial_key = Config.TRIAL_KEY
    new_task_id = str(uuid.uuid4())

    system_logger.info("用户发起视频生成任务 (积分模式)", model=model_value, cost=cost)

    # Trailing True: video currently always runs in points (trial) mode.
    worker_args = (app, user_id, new_task_id, payload, trial_key, cost, True)
    threading.Thread(target=process_video_generation, args=worker_args).start()

    return new_task_id
|