ai_v/blueprints/api.py

226 lines
7.6 KiB
Python

from flask import Blueprint, request, jsonify, session, current_app
from extensions import db, redis_client
from models import User
from middlewares.auth import login_required
from services.logger import system_logger
import json
# Import Services
from services.system_service import get_system_config_data, get_user_latest_notification, mark_notification_as_read
from services.history_service import get_user_history_data
from services.file_service import handle_file_uploads, get_remote_file_stream
from services.generation_service import (
validate_generation_request, deduct_points, handle_chat_generation_sync,
start_async_image_task, validate_video_request, start_async_video_task
)
# Blueprint that groups every /api/* endpoint defined in this module.
api_bp = Blueprint(name='api', import_name=__name__)
@api_bp.route('/api/task_status/<task_id>')
@login_required
def get_task_status(task_id):
    """Report the state of an asynchronous task.

    Reads the JSON record stored via ``redis_client`` under ``task:<task_id>``.
    When no record exists yet, the task is reported as ``{"status": "pending"}``.
    Any failure while reading or decoding is logged and answered with an error
    payload (still HTTP 200).
    """
    # NOTE(review): the record is returned to any logged-in user who knows the
    # task id; confirm ids are unguessable or add an ownership check.
    try:
        raw = redis_client.get(f"task:{task_id}")
        if raw:
            # The client may return bytes or str depending on decode_responses.
            text = raw.decode('utf-8') if isinstance(raw, bytes) else raw
            return jsonify(json.loads(text))
        return jsonify({"status": "pending"})
    except Exception as exc:
        system_logger.error(f"查询任务状态异常: {str(exc)}")
        return jsonify({"status": "error", "message": "状态查询失败"})
@api_bp.route('/api/config')
def get_config():
    """Expose the system configuration (read from the local DB dictionary) as JSON.

    Public endpoint (no login required). On failure, responds with HTTP 500 and
    the exception text under ``error``.
    """
    try:
        config = get_system_config_data()
        return jsonify(config)
    except Exception as exc:
        return jsonify({"error": str(exc)}), 500
@api_bp.route('/api/upload', methods=['POST'])
@login_required
def upload():
    """Store the files posted under the ``images`` form field and return their URLs.

    Responds with ``{"urls": [...]}`` as produced by ``handle_file_uploads``, or
    with HTTP 500 and the exception text if any step fails.
    """
    try:
        uploaded = request.files.getlist('images')
        urls = handle_file_uploads(uploaded)
        current_user = session.get('user_id')
        system_logger.info(f"用户上传文件: {len(uploaded)}", user_id=current_user)
        return jsonify({"urls": urls})
    except Exception as exc:
        return jsonify({"error": str(exc)}), 500
def _collect_image_data(data):
    """Return the submitted ``image_data`` entries as a list.

    Form submissions are MultiDicts whose ``.get`` yields only the first value as
    a bare string, which would then be iterated character by character, so
    ``getlist`` is used to collect every value. A single string sent via JSON is
    wrapped in a list. A missing or empty value yields ``[]``.
    """
    if hasattr(data, 'getlist'):
        return data.getlist('image_data')
    images = data.get('image_data') or []
    if isinstance(images, str):
        return [images]
    return list(images)


def _build_image_payload(data, model_value, prompt):
    """Build the upstream image-generation request body from the client data."""
    payload = {
        "prompt": prompt,
        "model": model_value,
        "response_format": "url",
        "aspect_ratio": data.get('ratio'),
    }
    images = _collect_image_data(data)
    if images:
        # Keep only what follows the first comma (presumably a
        # "data:<mime>;base64," prefix); entries without a comma pass through.
        payload["image"] = [img.split(',', 1)[1] if ',' in img else img for img in images]
    if model_value == "nano-banana-2" and data.get('size'):
        payload["image_size"] = data.get('size')
    return payload


@api_bp.route('/api/generate', methods=['POST'])
@login_required
def generate():
    """Start a generation request for the current user.

    Models whose name contains "gemini" or "gpt" are treated as chat models and
    answered synchronously via ``handle_chat_generation_sync``. Every other model
    is dispatched as an async image task, and the response carries its
    ``task_id``.

    Accepts a JSON object body or form data. Returns 400 for invalid input or
    validation failures, and 500 (with the exception text) for unexpected errors.
    """
    try:
        user_id = session.get('user_id')
        user = db.session.get(User, user_id)
        if request.is_json:
            # silent=True turns a malformed body into None instead of raising.
            data = request.get_json(silent=True)
            if not isinstance(data, dict):
                return jsonify({"error": "请求体必须是 JSON 对象"}), 400
        else:
            data = request.form

        # 1. Validate the request and the user's permissions.
        api_key, target_api, cost, use_trial, error = validate_generation_request(user, data)
        if error:
            return jsonify({"error": error}), 400

        model_value = data.get('model')
        # Reject before charging points: .lower() below needs a non-empty string.
        if not isinstance(model_value, str) or not model_value:
            return jsonify({"error": "缺少 model 参数"}), 400
        prompt = data.get('prompt')
        lowered_model = model_value.lower()
        is_chat_model = "gemini" in lowered_model or "gpt" in lowered_model

        # Build the image payload before charging, so malformed image input
        # fails without deducting points.
        payload = None if is_chat_model else _build_image_payload(data, model_value, prompt)

        # 2. Deduct points (trial mode only).
        if use_trial:
            deduct_points(user_id, cost)

        # 3. Chat models are answered synchronously.
        if is_chat_model:
            result, status_code = handle_chat_generation_sync(user_id, api_key, model_value, prompt, use_trial, cost)
            return jsonify(result), status_code

        # 4. Everything else runs as an async image task.
        # NOTE(review): if this call raises, the points deducted above are not
        # refunded here; confirm whether start_async_image_task handles refunds.
        app = current_app._get_current_object()
        task_id = start_async_image_task(app, user_id, payload, api_key, target_api, cost, data.get('mode'), model_value)
        return jsonify({
            "task_id": task_id,
            "message": "已开启异步生成任务"
        })
    except Exception as e:
        system_logger.error(f"生成请求异常: {str(e)}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/api/video/generate', methods=['POST'])
@login_required
def video_generate():
    """Validate a video generation request, charge its cost and start an async task.

    Expects a JSON object body. Returns 400 when the body is missing, malformed
    or not an object, or when ``validate_video_request`` reports an error.
    Otherwise returns the ``task_id`` of the started video task. Unexpected
    errors are logged and answered with HTTP 500.
    """
    try:
        user_id = session.get('user_id')
        user = db.session.get(User, user_id)
        # request.json raises (415/400) on non-JSON or malformed bodies and may
        # yield None or a list; normalize all of these into a clean 400.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "请求体必须是 JSON 对象"}), 400

        # 1. Validate the request.
        model_value, cost, error = validate_video_request(user, data)
        if error:
            return jsonify({"error": error}), 400

        # 2. Build the upstream payload before any points are charged.
        payload = {
            "model": model_value,
            "prompt": data.get('prompt'),
            "enhance_prompt": data.get('enhance_prompt', False),
            "images": data.get('images', []),
            "aspect_ratio": data.get('aspect_ratio', '9:16')
        }

        # 3. Deduct points.
        deduct_points(user_id, cost)

        # 4. Start the async video task.
        # NOTE(review): if this call raises, the deducted points are not refunded here.
        app = current_app._get_current_object()
        task_id = start_async_video_task(app, user_id, payload, cost, model_value)
        return jsonify({
            "task_id": task_id,
            "message": "视频生成任务已提交,系统正在导演中..."
        })
    except Exception as e:
        system_logger.error(f"视频生成请求异常: {str(e)}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/api/notifications/latest', methods=['GET'])
@login_required
def get_latest_notification():
    """Return whatever ``get_user_latest_notification`` yields for the current user.

    On failure, responds with HTTP 500 and the exception text under ``error``.
    """
    try:
        notification = get_user_latest_notification(session.get('user_id'))
        return jsonify(notification)
    except Exception as exc:
        return jsonify({"error": str(exc)}), 500
@api_bp.route('/api/notifications/read', methods=['POST'])
@login_required
def mark_notif_read():
    """Mark one notification as read for the current user.

    Expects a JSON object body with an ``id`` field. Returns 400 when the id is
    absent, including when the body is missing, malformed or not a JSON object.
    Returns ``{"status": "ok"}`` on success and HTTP 500 with the exception text
    on unexpected errors.
    """
    try:
        user_id = session.get('user_id')
        # request.json could be None or a non-dict; data.get would then raise
        # AttributeError and turn a client error into a 500.
        data = request.get_json(silent=True)
        notif_id = data.get('id') if isinstance(data, dict) else None
        if not notif_id:
            return jsonify({"error": "缺少通知 ID"}), 400
        mark_notification_as_read(user_id, notif_id)
        return jsonify({"status": "ok"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
# Upper bound on the history page size, so a client cannot request unbounded
# result sets. Raise it if the frontend legitimately needs larger pages.
_HISTORY_MAX_PER_PAGE = 100


@api_bp.route('/api/history', methods=['GET'])
@login_required
def get_history():
    """Return one page of the current user's history records.

    Query parameters:
        page: page number (default 1; values below 1 are raised to 1).
        per_page: page size (default 10; clamped to 1.._HISTORY_MAX_PER_PAGE).
        filter_type: forwarded unchanged to ``get_user_history_data`` (default "all").

    On failure, responds with HTTP 500 and the exception text under ``error``.
    """
    try:
        user_id = session.get('user_id')
        page = max(request.args.get('page', 1, type=int), 1)
        per_page = request.args.get('per_page', 10, type=int)
        per_page = min(max(per_page, 1), _HISTORY_MAX_PER_PAGE)
        filter_type = request.args.get('filter_type', 'all')
        data = get_user_history_data(user_id, page, per_page, filter_type)
        return jsonify(data)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@api_bp.route('/api/stats/points', methods=['GET'])
@login_required
def get_user_point_stats():
    """Return chart data for the current user's point statistics.

    Query parameters:
        days: forwarded to ``get_point_stats`` (int, default 7).

    Failures are logged and reported as JSON HTTP 500, matching the other
    endpoints in this blueprint, instead of Flask's default HTML error page.
    """
    try:
        # Imported lazily (as before); kept inside the try so import errors also
        # produce a JSON error response.
        from services.stats_service import get_point_stats
        user_id = session.get('user_id')
        days = request.args.get('days', 7, type=int)
        return jsonify(get_point_stats(user_id, days))
    except Exception as e:
        system_logger.error(f"积分统计查询异常: {str(e)}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/api/stats/details', methods=['GET'])
@login_required
def get_user_point_details():
    """Return one page of the current user's point consumption details.

    Query parameters:
        page: page number forwarded to ``get_point_details`` (default 1; values
            below 1 are raised to 1).

    Failures are logged and reported as JSON HTTP 500, matching the other
    endpoints in this blueprint, instead of Flask's default HTML error page.
    """
    try:
        # Imported lazily (as before); kept inside the try so import errors also
        # produce a JSON error response.
        from services.stats_service import get_point_details
        user_id = session.get('user_id')
        page = max(request.args.get('page', 1, type=int), 1)
        return jsonify(get_point_details(user_id, page))
    except Exception as e:
        system_logger.error(f"积分明细查询异常: {str(e)}")
        return jsonify({"error": str(e)}), 500
def _attachment_header(filename, fallback):
    """Build a Content-Disposition value that is safe for an arbitrary filename.

    Quotes, backslashes and non-printable characters (including CR/LF) are
    dropped, so the user-supplied name cannot break out of the quoted string or
    the header line. An ASCII-only ``filename`` is kept for older clients, and
    the full name is sent percent-encoded as ``filename*`` (RFC 6266/5987).
    Non-ASCII names therefore no longer produce a header that cannot be encoded
    as latin-1.
    """
    from urllib.parse import quote

    cleaned = "".join(ch for ch in filename if ch.isprintable() and ch not in '"\\')
    cleaned = cleaned.strip() or fallback
    ascii_name = cleaned.encode('ascii', 'ignore').decode('ascii').strip() or fallback
    return f"attachment; filename=\"{ascii_name}\"; filename*=UTF-8''{quote(cleaned, safe='')}"


@api_bp.route('/api/download_proxy', methods=['GET'])
@login_required
def download_proxy():
    """Stream a remote file back to the client as an attachment download.

    Query parameters:
        url: file to fetch via ``get_remote_file_stream`` (required; 400 if absent
            or if it uses an explicit non-HTTP scheme).
        filename: suggested download name (default ``video-<unix time>.mp4``).

    Upstream failures are logged and answered with HTTP 500.
    """
    import time
    from urllib.parse import urlsplit

    url = request.args.get('url')
    # Default download name.
    default_name = f"video-{int(time.time())}.mp4"
    filename = request.args.get('filename') or default_name
    if not url:
        return jsonify({"error": "缺少 URL 参数"}), 400

    # NOTE(review): url is user-controlled, so this endpoint can be pointed at
    # arbitrary hosts (SSRF). Confirm get_remote_file_stream restricts targets.
    # Here we only refuse explicit non-HTTP schemes such as file://; relative
    # URLs (no scheme) are still passed through unchanged.
    try:
        scheme = urlsplit(url).scheme.lower()
    except ValueError:
        return jsonify({"error": "不支持的 URL"}), 400
    if scheme and scheme not in ('http', 'https'):
        return jsonify({"error": "不支持的 URL"}), 400

    upstream = None
    try:
        upstream, headers = get_remote_file_stream(url)
        headers['Content-Disposition'] = _attachment_header(filename, default_name)

        def stream_chunks():
            yield from upstream.iter_content(chunk_size=4096)

        response = current_app.response_class(stream_chunks(), headers=headers)
        # Release the upstream connection when the response is closed. This
        # also covers a client that disconnects before the body is fully sent.
        closer = getattr(upstream, 'close', None)
        if closer is not None:
            response.call_on_close(closer)
        return response
    except Exception as e:
        system_logger.error(f"代理下载失败: {str(e)}")
        closer = getattr(upstream, 'close', None)
        if closer is not None:
            closer()
        return jsonify({"error": "下载失败"}), 500