# ai_v/services/generation_service.py
#
# 157 lines
# 5.1 KiB
# Python

from config import Config
from models import SystemDict, GenerationRecord, User, db
from services.logger import system_logger
from services.task_service import process_image_generation, process_video_generation
import requests
import json
import uuid
import threading
from flask import current_app
def get_model_cost(model_value, is_video=False):
    """Return the number of points charged for one generation with a model.

    The cost is read from the ``SystemDict`` row whose ``dict_type`` is
    ``'video_model'`` (when ``is_video`` is true) or ``'ai_model'`` and whose
    ``value`` equals ``model_value``. When no such row exists a built-in
    default applies: 1 for non-video models; for video models 15 if the
    model name contains "pro" (case-insensitive) or "3.1", otherwise 10.

    Args:
        model_value: Model identifier taken from the request. ``None`` or a
            non-string value is tolerated and never matches the
            "pro" / "3.1" rule.
        is_video: Use the video price table and video defaults.

    Returns:
        The ``cost`` attribute of the matching row, or the default cost.
    """
    dict_type = 'video_model' if is_video else 'ai_model'
    entry = SystemDict.query.filter_by(dict_type=dict_type, value=model_value).first()
    if entry:
        return entry.cost

    # Default costs
    if not is_video:
        return 1

    # The model name originates from request data, so it can be missing or
    # not a string; treat that as "no premium marker" instead of raising.
    name = model_value if isinstance(model_value, str) else ""
    return 15 if "pro" in name.lower() or "3.1" in name else 10
def validate_generation_request(user, data):
    """Validate an image-generation request and resolve its configuration.

    Two modes are supported, selected by ``data['mode']`` (default
    ``'trial'``):

    * ``'key'``: the request's ``apiKey`` (or, failing that, the key stored
      on ``user``) is used against ``Config.AI_API``. A new, different
      ``apiKey`` is saved on the user and committed.
    * anything else: points mode. The user must have a positive balance; the
      premium or trial key from ``Config`` is used against
      ``Config.TRIAL_API``, and the cost is doubled when ``is_premium`` is
      set.

    Returns:
        A 5-tuple ``(api_key, target_api, cost, use_trial, error)``. On
        success ``error`` is ``None``; on failure ``api_key`` and
        ``target_api`` are ``None`` and ``error`` holds the message.
    """
    default_endpoint = Config.AI_API
    requested_model = data.get('model')
    wants_premium = data.get('is_premium', False)
    supplied_key = data.get('apiKey')

    if data.get('mode', 'trial') == 'key':
        resolved_key = supplied_key or user.api_key
        if not resolved_key:
            return None, None, 0, False, "请先输入您的 API 密钥"
        # Remember a newly supplied key for later requests.
        if supplied_key and supplied_key != user.api_key:
            user.api_key = supplied_key
            db.session.commit()
        key_mode_cost = get_model_cost(requested_model, is_video=False)
        return resolved_key, default_endpoint, key_mode_cost, False, None

    # Points mode from here on.
    if user.points <= 0:
        return None, None, 0, False, "可用积分已耗尽,请充值或切换至自定义 Key 模式"

    shared_key = Config.PREMIUM_KEY if wants_premium else Config.TRIAL_KEY
    trial_endpoint = Config.TRIAL_API

    cost = get_model_cost(requested_model, is_video=False)
    if wants_premium:
        cost *= 2

    if user.points < cost:
        return None, None, cost, True, "可用积分不足"
    return shared_key, trial_endpoint, cost, True, None
def deduct_points(user, cost):
    """Subtract ``cost`` points from ``user`` and commit the change.

    Also sets ``user.has_used_points`` to ``True``.
    """
    remaining = user.points - cost
    user.points = remaining
    user.has_used_points = True
    db.session.commit()
def refund_points(user_id, cost):
    """Give ``cost`` points back to the user identified by ``user_id``.

    This is a best-effort operation: it never raises. If the user does not
    exist nothing happens. If the lookup or commit fails, the failure is
    logged and the session is rolled back so it stays usable.

    Args:
        user_id: Primary key of the ``User`` to refund.
        cost: Number of points to add back to ``user.points``.
    """
    import logging

    log = logging.getLogger(__name__)
    try:
        user = db.session.get(User, user_id)
        if user:
            user.points += cost
            db.session.commit()
    except Exception:
        # A lost refund costs the user points, so record it rather than
        # swallowing it silently.
        log.exception("Failed to refund %s points to user %s", cost, user_id)
        try:
            # Discard the failed transaction so later work on this session
            # is not blocked by it.
            db.session.rollback()
        except Exception:
            log.exception("Rollback after failed refund also failed")
def handle_chat_generation_sync(user_id, api_key, model_value, prompt, use_trial, cost):
    """Run a chat-style model synchronously and return ``(body, status)``.

    Sends ``prompt`` as a single user message to ``Config.CHAT_API``. On
    success the reply text is returned and, unless the prompt is exactly
    "解读验光单", stored as a ``GenerationRecord``.

    On any failure ``cost`` is refunded through ``refund_points`` when
    ``use_trial`` is true.

    Args:
        user_id: Id of the requesting user; used for the history row and
            for refunds.
        api_key: Bearer token sent to the chat API.
        model_value: Model identifier placed in the request payload.
        prompt: User message text.
        use_trial: Whether the request runs in points mode (refund on error).
        cost: Points recorded on the history row and refunded on error.

    Returns:
        ``({"data": [...], "message": ...}, 200)`` on success;
        ``({"error": <response text>}, <status>)`` when the API answers with
        a non-200 status; ``({"error": str(exc)}, 500)`` on any exception.
    """
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    chat_payload = {
        "model": model_value,
        "messages": [{"role": "user", "content": prompt}]
    }
    try:
        resp = requests.post(Config.CHAT_API, json=chat_payload, headers=headers, timeout=120)
        if resp.status_code != 200:
            if use_trial:
                refund_points(user_id, cost)
            return {"error": resp.text}, resp.status_code

        api_result = resp.json()
        content = api_result['choices'][0]['message']['content']

        # Record chat history (requests with this exact prompt are not saved).
        if prompt != "解读验光单":
            new_record = GenerationRecord(
                user_id=user_id,
                prompt=prompt,
                model=model_value,
                cost=cost,
                image_urls=json.dumps([{"type": "text", "content": content}])
            )
            try:
                db.session.add(new_record)
                db.session.commit()
            except Exception:
                # A failed flush/commit may leave the session unusable until
                # it is rolled back. Reset it so the refund in the outer
                # handler can still use db.session, then report the error.
                db.session.rollback()
                raise

        return {
            "data": [{"content": content, "type": "text"}],
            "message": "生成成功!"
        }, 200
    except Exception as e:
        if use_trial:
            refund_points(user_id, cost)
        return {"error": str(e)}, 500
def start_async_image_task(app, user_id, payload, api_key, target_api, cost, mode, model_value):
    """Start an image-generation job on a background thread.

    A fresh UUID4 string is generated as the task id, the request is logged
    (with a distinct message when the prompt is exactly "解读验光单"), and
    ``process_image_generation`` is started in a new thread.

    Returns:
        The task id string.
    """
    task_id = str(uuid.uuid4())

    if payload.get('prompt') == "解读验光单":
        log_msg = "用户发起验光单解读"
    else:
        log_msg = "用户发起生图任务"
    system_logger.info(log_msg, model=model_value, mode=mode)

    worker_args = (app, user_id, task_id, payload, api_key, target_api, cost)
    worker = threading.Thread(target=process_image_generation, args=worker_args)
    worker.start()
    return task_id
def validate_video_request(user, data):
    """Check whether ``user`` has enough points for a video generation.

    The model is taken from ``data['model']`` (default ``'veo3.1'``) and
    priced with ``get_model_cost(..., is_video=True)``.

    Returns:
        A 3-tuple ``(model_value, cost, error)``. On success ``error`` is
        ``None``. On failure ``model_value`` is ``None`` and ``error`` holds
        the message; ``cost`` is 0 when the balance is not positive,
        otherwise the price that could not be covered.
    """
    balance_exhausted = user.points <= 0
    if balance_exhausted:
        return None, 0, "可用积分不足,请先充值"

    model_value = data.get('model', 'veo3.1')
    cost = get_model_cost(model_value, is_video=True)

    if user.points >= cost:
        return model_value, cost, None
    return None, cost, f"积分不足,生成该视频需要 {cost} 积分"
def start_async_video_task(app, user_id, payload, cost, model_value):
    """Start a video-generation job on a background thread.

    Video jobs always use ``Config.TRIAL_KEY``. A fresh UUID4 string is
    generated as the task id, the request is logged, and
    ``process_video_generation`` is started in a new thread.

    Returns:
        The task id string.
    """
    api_key = Config.TRIAL_KEY
    task_id = str(uuid.uuid4())

    system_logger.info("用户发起视频生成任务 (积分模式)", model=model_value, cost=cost)

    worker_args = (app, user_id, task_id, payload, api_key, cost)
    worker = threading.Thread(target=process_video_generation, args=worker_args)
    worker.start()
    return task_id