import os
import uuid
import json
import requests
import io
import threading
import time
import base64
from flask import Blueprint, request, jsonify, session, current_app
from urllib.parse import quote
from config import Config
from extensions import s3_client, redis_client, db
from models import GenerationRecord, User, SystemDict, SystemNotification
from middlewares.auth import login_required
from services.logger import system_logger

api_bp = Blueprint('api', __name__)

# Lifetime (seconds) of the "task:<id>" status entries written to Redis.
_TASK_TTL_SECONDS = 3600

# Maps SystemDict.dict_type -> (key in the /api/config response, include cost?).
_DICT_TYPE_TO_CONFIG = {
    'ai_model': ("models", True),
    'aspect_ratio': ("ratios", False),
    'prompt_tpl': ("prompts", False),
    'ai_image_size': ("sizes", False),
    'video_model': ("video_models", True),
    'video_prompt': ("video_prompts", False),
}


def _refund_points(user_id, cost):
    """Credit ``cost`` points back to the user; never raises.

    The session is rolled back first so that a refund still works when the
    caller is handling a failed flush/commit. Failures are logged only, so
    the caller can always go on to publish the task's error status.
    """
    if not cost:
        return
    try:
        db.session.rollback()
        user = db.session.get(User, user_id)
        if user:
            user.points += cost
            db.session.commit()
    except Exception as refund_err:
        system_logger.error(f"退费失败: {str(refund_err)}")


def _upload_image_with_thumbnail(content):
    """Upload image bytes plus a JPEG thumbnail to MinIO.

    Returns ``{"url": ..., "thumb": ...}``. If thumbnail creation or upload
    fails, ``thumb`` falls back to the full-size URL. Errors from the
    full-size upload propagate to the caller.
    """
    base_filename = f"gen-{uuid.uuid4().hex}"
    full_filename = f"{base_filename}.png"
    thumb_filename = f"{base_filename}-thumb.jpg"

    # 1. Upload the original image.
    s3_client.upload_fileobj(
        io.BytesIO(content),
        Config.MINIO["bucket"],
        full_filename,
        ExtraArgs={"ContentType": "image/png"}
    )
    full_url = f"{Config.MINIO['public_url']}{quote(full_filename)}"
    thumb_url = full_url  # default to the original image

    # 2. Build and upload a thumbnail (at most 400px wide).
    try:
        from PIL import Image
        img = Image.open(io.BytesIO(content))
        # JPEG cannot store alpha/palette images, so convert those to RGB.
        if img.mode in ("RGBA", "P"):
            img = img.convert("RGB")

        w, h = img.size
        if w > 400:
            ratio = 400 / float(w)
            # Clamp to 1 so extremely wide images never request a 0px height.
            img.thumbnail((400, max(1, int(h * ratio))), Image.Resampling.LANCZOS)

        thumb_io = io.BytesIO()
        # Thumbnails are saved as JPEG for a smaller payload.
        img.save(thumb_io, format='JPEG', quality=80, optimize=True)
        thumb_io.seek(0)

        s3_client.upload_fileobj(
            thumb_io,
            Config.MINIO["bucket"],
            thumb_filename,
            ExtraArgs={"ContentType": "image/jpeg"}
        )
        thumb_url = f"{Config.MINIO['public_url']}{quote(thumb_filename)}"
    except Exception as thumb_e:
        print(f"⚠️ 缩略图生成失败: {thumb_e}")

    return {"url": full_url, "thumb": thumb_url}


def sync_images_background(app, record_id, raw_urls):
    """Mirror generated images to MinIO in the background, with retries.

    Each URL is downloaded up to three times with exponential backoff. On
    success the entry becomes ``{"url", "thumb"}`` pointing at MinIO; on
    final failure the original remote URL is kept for both fields. The
    ``GenerationRecord`` identified by ``record_id`` is then updated with
    the JSON-encoded list.
    """
    max_attempts = 3
    with app.app_context():
        processed_data = []
        for raw_url in raw_urls:
            success = False
            for attempt in range(max_attempts):
                try:
                    img_resp = requests.get(raw_url, timeout=30)
                    if img_resp.status_code != 200:
                        # Route bad statuses through the same log + backoff path.
                        raise RuntimeError(f"HTTP {img_resp.status_code}")
                    processed_data.append(_upload_image_with_thumbnail(img_resp.content))
                    success = True
                    break
                except Exception as e:
                    print(f"⚠️ 第 {attempt+1} 次同步失败: {e}")
                    if attempt < max_attempts - 1:
                        time.sleep(2 ** attempt)  # exponential backoff

            if not success:
                # Keep the original remote URL if every attempt failed.
                processed_data.append({"url": raw_url, "thumb": raw_url})

        # Persist the final URL structure on the record.
        try:
            record = db.session.get(GenerationRecord, record_id)
            if record:
                record.image_urls = json.dumps(processed_data)
                db.session.commit()
                print(f"✅ 记录 {record_id} 图片及缩略图已完成同步")
        except Exception as e:
            print(f"❌ 更新记录失败: {e}")


def process_image_generation(app, user_id, task_id, payload, api_key, target_api, cost):
    """Run an image-generation request and publish the outcome to Redis.

    On success a ``GenerationRecord`` is stored, a background MinIO sync is
    started and ``task:<task_id>`` is set to ``complete`` with the raw URLs.
    On any failure ``cost`` points are credited back (when the key contains
    ``"sk-"``) and the task is set to ``error``.

    ``cost`` is the number of points to refund on failure; callers should
    pass 0 when nothing was charged.
    """
    # NOTE(review): the "sk-" substring test is kept from the original code;
    # confirm it is the intended way to decide whether a refund applies.
    refundable = bool(api_key) and "sk-" in api_key
    with app.app_context():
        try:
            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
            # Generous timeout (1000 s) so long-running generations are not cut off.
            resp = requests.post(target_api, json=payload, headers=headers, timeout=1000)

            if resp.status_code != 200:
                if refundable:
                    _refund_points(user_id, cost)
                system_logger.error(f"生图任务失败: {resp.text}", user_id=user_id, task_id=task_id)
                redis_client.setex(
                    f"task:{task_id}", _TASK_TTL_SECONDS,
                    json.dumps({"status": "error", "message": resp.text})
                )
                return

            api_result = resp.json()
            raw_urls = [item['url'] for item in api_result.get('data', [])]

            # Persist the record with the provider's URLs first.
            new_record = GenerationRecord(
                user_id=user_id,
                prompt=payload.get('prompt'),
                model=payload.get('model'),
                image_urls=json.dumps(raw_urls)
            )
            db.session.add(new_record)
            db.session.commit()

            # Background thread: download the images and mirror them to MinIO.
            threading.Thread(
                target=sync_images_background,
                args=(app, new_record.id, raw_urls)
            ).start()

            # Mark the task as complete in Redis.
            system_logger.info("生图任务完成", user_id=user_id, task_id=task_id, model=payload.get('model'))
            redis_client.setex(
                f"task:{task_id}", _TASK_TTL_SECONDS,
                json.dumps({"status": "complete", "urls": raw_urls})
            )
        except Exception as e:
            # Refund first (never raises), then always publish the error so
            # the polling frontend does not wait forever.
            if refundable:
                _refund_points(user_id, cost)
            system_logger.error(f"生图任务异常: {str(e)}", user_id=user_id, task_id=task_id)
            redis_client.setex(
                f"task:{task_id}", _TASK_TTL_SECONDS,
                json.dumps({"status": "error", "message": str(e)})
            )


def sync_video_background(app, record_id, raw_url, internal_task_id=None):
    """Mirror a generated video to MinIO in the background, with retries.

    After a successful upload the record's ``image_urls`` is rewritten to
    the MinIO URL and, when ``internal_task_id`` is given, the cached Redis
    task entry is refreshed with the new ``video_url``.
    """
    with app.app_context():
        success = False
        final_url = raw_url

        for attempt in range(3):
            try:
                # The body is read in chunks, but is still buffered fully in
                # memory before upload.
                with requests.get(raw_url, stream=True, timeout=120) as r:
                    r.raise_for_status()
                    content_type = r.headers.get('content-type', 'video/mp4')
                    ext = ".mp4"
                    if "text/html" in content_type:
                        # Some APIs answer with a redirect/landing page; retry.
                        continue

                    base_filename = f"video-{uuid.uuid4().hex}"
                    full_filename = f"{base_filename}{ext}"

                    video_io = io.BytesIO()
                    for chunk in r.iter_content(chunk_size=8192):
                        video_io.write(chunk)
                    video_io.seek(0)

                    # Upload to MinIO.
                    s3_client.upload_fileobj(
                        video_io,
                        Config.MINIO["bucket"],
                        full_filename,
                        ExtraArgs={"ContentType": content_type}
                    )
                    final_url = f"{Config.MINIO['public_url']}{quote(full_filename)}"
                    success = True
                    break
            except Exception as e:
                system_logger.error(f"同步视频失败 (第{attempt+1}次): {str(e)}")
                time.sleep(5)

        if success:
            try:
                record = db.session.get(GenerationRecord, record_id)
                if record:
                    # Point the record at the MinIO copy.
                    record.image_urls = json.dumps([{"url": final_url, "type": "video"}])
                    db.session.commit()

                # Also refresh the Redis cache so frontend polling sees the
                # MinIO address.
                if internal_task_id:
                    cached_data = redis_client.get(f"task:{internal_task_id}")
                    if cached_data:
                        if isinstance(cached_data, bytes):
                            cached_data = cached_data.decode('utf-8')
                        task_info = json.loads(cached_data)
                        task_info['video_url'] = final_url
                        redis_client.setex(
                            f"task:{internal_task_id}", _TASK_TTL_SECONDS,
                            json.dumps(task_info)
                        )

                system_logger.info("视频同步 MinIO 成功", video_url=final_url)
            except Exception as dbe:
                system_logger.error(f"更新视频记录失败: {str(dbe)}")


def _extract_video_url(poll_result):
    """Return the video URL from a successful poll response, or None.

    Checks, in order: ``data.output`` (dict), ``data[0].url`` (list),
    ``video`` (dict with ``url`` or a plain value) and top-level ``url``.
    """
    data = poll_result.get('data')
    video_url = None
    if isinstance(data, dict):
        video_url = data.get('output')
    if not video_url:
        if isinstance(data, list) and data:
            video_url = data[0].get('url')
        elif 'video' in poll_result:
            video = poll_result['video']
            video_url = video.get('url') if isinstance(video, dict) else video
        elif 'url' in poll_result:
            video_url = poll_result['url']
    return video_url


def process_video_generation(app, user_id, internal_task_id, payload, api_key, cost):
    """Submit a video task, poll it to completion and publish the result.

    On success a ``GenerationRecord`` is stored, a background MinIO sync is
    started and ``task:<internal_task_id>`` is set to ``complete``. On any
    failure ``cost`` points are refunded and the task is set to ``error``.
    """
    with app.app_context():
        try:
            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

            # 1. Submit the task.
            submit_resp = requests.post(Config.VIDEO_GEN_API, json=payload, headers=headers, timeout=60)
            if submit_resp.status_code != 200:
                raise Exception(f"视频任务提交失败: {submit_resp.text}")

            submit_result = submit_resp.json()
            remote_task_id = submit_result.get('task_id')
            if not remote_task_id:
                raise Exception(f"未获取到远程任务 ID: {submit_result}")

            # 2. Poll for status: 90 polls x 10 s sleep = 15 minutes of waiting.
            max_retries = 90
            video_url = None
            for _ in range(max_retries):
                time.sleep(10)
                poll_url = Config.VIDEO_POLL_API.format(task_id=remote_task_id)
                poll_resp = requests.get(poll_url, headers=headers, timeout=30)
                if poll_resp.status_code != 200:
                    continue

                poll_result = poll_resp.json()
                # Tolerate a missing or null status field.
                status = str(poll_result.get('status') or '').upper()

                if status == 'SUCCESS':
                    video_url = _extract_video_url(poll_result)
                    break
                elif status in ['FAILED', 'ERROR']:
                    raise Exception(f"视频生成失败: {poll_result.get('fail_reason') or poll_result.get('message') or '未知错误'}")

            if not video_url:
                raise Exception("超时未获取到视频地址")

            # 3. Persist the record.
            new_record = GenerationRecord(
                user_id=user_id,
                prompt=payload.get('prompt'),
                model=payload.get('model'),
                image_urls=json.dumps([{"url": video_url, "type": "video"}])
            )
            db.session.add(new_record)
            db.session.commit()

            # Background thread mirrors the video to MinIO.
            threading.Thread(
                target=sync_video_background,
                args=(app, new_record.id, video_url, internal_task_id)
            ).start()

            # 4. Publish the result to Redis.
            redis_client.setex(
                f"task:{internal_task_id}", _TASK_TTL_SECONDS,
                json.dumps({"status": "complete", "video_url": video_url, "record_id": new_record.id})
            )
            system_logger.info("视频生成任务完成", user_id=user_id, task_id=internal_task_id)

        except Exception as e:
            system_logger.error(f"视频生成执行异常: {str(e)}", user_id=user_id, task_id=internal_task_id)
            # Best-effort refund (rolls back any failed transaction first).
            _refund_points(user_id, cost)
            # Always update the Redis status so the frontend stops polling.
            redis_client.setex(
                f"task:{internal_task_id}", _TASK_TTL_SECONDS,
                json.dumps({"status": "error", "message": str(e)})
            )


@api_bp.route('/api/task_status/<task_id>')
@login_required
def get_task_status(task_id):
    """Return the cached status of an async task.

    Responds with ``{"status": "pending"}`` while no Redis entry exists.
    """
    try:
        data = redis_client.get(f"task:{task_id}")
        if not data:
            return jsonify({"status": "pending"})

        # Redis clients may return bytes or str depending on configuration.
        if isinstance(data, bytes):
            data = data.decode('utf-8')

        return jsonify(json.loads(data))
    except Exception as e:
        system_logger.error(f"查询任务状态异常: {str(e)}")
        return jsonify({"status": "error", "message": "状态查询失败"})


@api_bp.route('/api/config')
def get_config():
    """Build the frontend config from active ``SystemDict`` rows.

    Rows are grouped by ``dict_type``; model entries also carry ``cost``.
    Rows with an unknown ``dict_type`` are ignored.
    """
    try:
        dicts = SystemDict.query.filter_by(is_active=True).order_by(SystemDict.sort_order.desc()).all()

        config = {
            "models": [],
            "ratios": [],
            "prompts": [],
            "sizes": [],
            "video_models": [],
            "video_prompts": []
        }

        for d in dicts:
            target = _DICT_TYPE_TO_CONFIG.get(d.dict_type)
            if target is None:
                continue
            config_key, with_cost = target
            item = {"label": d.label, "value": d.value}
            if with_cost:
                item["cost"] = d.cost
            config[config_key].append(item)

        return jsonify(config)
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@api_bp.route('/api/upload', methods=['POST'])
@login_required
def upload():
    """Upload the files posted under ``images`` to MinIO and return their URLs."""
    try:
        files = request.files.getlist('images')
        img_urls = []
        for f in files:
            # filename may be None for malformed parts; treat as no extension.
            ext = os.path.splitext(f.filename or "")[1]
            filename = f"{uuid.uuid4().hex}{ext}"
            s3_client.upload_fileobj(
                f, Config.MINIO["bucket"], filename,
                ExtraArgs={"ContentType": f.content_type}
            )
            img_urls.append(f"{Config.MINIO['public_url']}{quote(filename)}")

        system_logger.info(f"用户上传文件: {len(files)} 个", user_id=session.get('user_id'))
        return jsonify({"urls": img_urls})
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@api_bp.route('/api/generate', methods=['POST'])
@login_required
def generate():
    """Start an image generation task, or answer a chat model synchronously.

    In ``key`` mode the user's own API key is used and no points are
    charged. Otherwise points are deducted up front and refunded if the
    generation fails. Chat models (name contains "gemini" or "gpt") are
    answered inline; everything else runs in a background thread and the
    response carries a ``task_id`` to poll.
    """
    try:
        user_id = session.get('user_id')
        user = db.session.get(User, user_id)

        data = request.json if request.is_json else request.form
        mode = data.get('mode', 'trial')
        is_premium = data.get('is_premium', False)
        input_key = data.get('apiKey')

        target_api = Config.AI_API
        api_key = None
        use_trial = False

        if mode == 'key':
            api_key = input_key or user.api_key
            if not api_key:
                return jsonify({"error": "请先输入您的 API 密钥"}), 400
        else:
            if user.points > 0:
                api_key = Config.PREMIUM_KEY if is_premium else Config.TRIAL_KEY
                target_api = Config.TRIAL_API
                use_trial = True
            else:
                return jsonify({"error": "可用积分已耗尽,请充值或切换至自定义 Key 模式"}), 400

        # Remember a newly supplied key for next time.
        if mode == 'key' and input_key and input_key != user.api_key:
            user.api_key = input_key
            db.session.commit()

        model_value = data.get('model')
        if not model_value:
            # Reject before charging instead of failing on None.lower().
            return jsonify({"error": "缺少模型参数"}), 400
        is_chat_model = "gemini" in model_value.lower() or "gpt" in model_value.lower()

        model_dict = SystemDict.query.filter_by(dict_type='ai_model', value=model_value).first()
        cost = model_dict.cost if model_dict else 1

        if use_trial and is_premium:
            cost *= 2

        if use_trial:
            if user.points < cost:
                return jsonify({"error": "可用积分不足"}), 400
            user.points -= cost
            user.has_used_points = True  # flag that points have been used
            db.session.commit()

        prompt = data.get('prompt')
        ratio = data.get('ratio')
        size = data.get('size')
        image_data = data.get('image_data', [])

        payload = {
            "prompt": prompt,
            "model": model_value,
            "response_format": "url",
            "aspect_ratio": ratio
        }
        if image_data:
            # Strip any "data:...;base64," prefix from each image.
            payload["image"] = [img.split(',', 1)[1] if ',' in img else img for img in image_data]

        if model_value == "nano-banana-2" and size:
            payload["image_size"] = size

        # Chat models are handled synchronously.
        if is_chat_model:
            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
            chat_payload = {
                "model": model_value,
                "messages": [{"role": "user", "content": prompt}]
            }
            try:
                resp = requests.post(Config.CHAT_API, json=chat_payload, headers=headers, timeout=120)
                content = None
                if resp.status_code == 200:
                    content = resp.json()['choices'][0]['message']['content']
            except Exception:
                # Timeout / malformed response: give the points back before
                # the outer handler reports the error.
                if use_trial:
                    _refund_points(user_id, cost)
                raise

            if resp.status_code != 200:
                if use_trial:
                    _refund_points(user_id, cost)
                return jsonify({"error": resp.text}), resp.status_code

            # Record chat history (prescription-reading requests are not stored).
            if prompt != "解读验光单":
                new_record = GenerationRecord(
                    user_id=user_id,
                    prompt=prompt,
                    model=model_value,
                    image_urls=json.dumps([{"type": "text", "content": content}])
                )
                db.session.add(new_record)
                db.session.commit()

            return jsonify({
                "data": [{"content": content, "type": "text"}],
                "message": "生成成功!"
            })

        # --- Asynchronous image generation ---
        task_id = str(uuid.uuid4())
        app = current_app._get_current_object()

        log_msg = "用户发起验光单解读" if prompt == "解读验光单" else "用户发起生图任务"
        system_logger.info(log_msg, model=model_value, mode=mode)

        # Only points that were actually deducted may be refunded on failure;
        # key-mode requests were never charged.
        refund_cost = cost if use_trial else 0

        threading.Thread(
            target=process_image_generation,
            args=(app, user_id, task_id, payload, api_key, target_api, refund_cost)
        ).start()

        return jsonify({
            "task_id": task_id,
            "message": "已开启异步生成任务"
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@api_bp.route('/api/video/generate', methods=['POST'])
@login_required
def video_generate():
    """Charge points and start a background video generation task."""
    try:
        user_id = session.get('user_id')
        user = db.session.get(User, user_id)

        data = request.json
        # Video generation always uses points; key mode is not offered.
        if user.points <= 0:
            return jsonify({"error": "可用积分不足,请先充值"}), 400

        model_value = data.get('model', 'veo3.1')

        # Cost comes from the dictionary when available, else a name-based default.
        model_dict = SystemDict.query.filter_by(dict_type='video_model', value=model_value).first()
        cost = model_dict.cost if model_dict else (15 if "pro" in model_value.lower() or "3.1" in model_value else 10)

        if user.points < cost:
            return jsonify({"error": f"积分不足,生成该视频需要 {cost} 积分"}), 400

        # Deduct points up front.
        user.points -= cost
        user.has_used_points = True
        db.session.commit()

        # Payload for the video API.
        payload = {
            "model": model_value,
            "prompt": data.get('prompt'),
            "enhance_prompt": data.get('enhance_prompt', False),
            "images": data.get('images', []),
            "aspect_ratio": data.get('aspect_ratio', '9:16')
        }

        # Use the built-in system key.
        api_key = Config.TRIAL_KEY

        task_id = str(uuid.uuid4())
        app = current_app._get_current_object()

        system_logger.info("用户发起视频生成任务 (积分模式)", model=model_value, cost=cost)

        threading.Thread(
            target=process_video_generation,
            args=(app, user_id, task_id, payload, api_key, cost)
        ).start()

        return jsonify({
            "task_id": task_id,
            "message": "视频生成任务已提交,系统正在导演中..."
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@api_bp.route('/api/notifications/latest', methods=['GET'])
@login_required
def get_latest_notification():
    """Return the newest active notification the user has not read yet."""
    try:
        user_id = session.get('user_id')
        latest = (
            SystemNotification.query.filter_by(is_active=True)
            .filter(~SystemNotification.read_by_users.any(id=user_id))
            .order_by(SystemNotification.created_at.desc())
            .first()
        )

        if latest:
            return jsonify({
                "id": latest.id,
                "title": latest.title,
                "content": latest.content
            })
        return jsonify({"id": None})
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@api_bp.route('/api/notifications/read', methods=['POST'])
@login_required
def mark_notif_read():
    """Mark a notification as read by the current user."""
    try:
        user_id = session.get('user_id')
        # A missing or non-JSON body is treated as an empty request (-> 400).
        data = request.get_json(silent=True) or {}
        notif_id = data.get('id')

        if not notif_id:
            return jsonify({"error": "缺少通知 ID"}), 400

        notif = db.session.get(SystemNotification, notif_id)
        user = db.session.get(User, user_id)

        if notif and user:
            if user not in notif.read_by_users:
                notif.read_by_users.append(user)
                db.session.commit()

        return jsonify({"status": "ok"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@api_bp.route('/api/history', methods=['GET'])
@login_required
def get_history():
    """Return the user's generation history (paginated, last 90 days)."""
    try:
        from datetime import datetime, timedelta
        user_id = session.get('user_id')
        page = request.args.get('page', 1, type=int)
        per_page = request.args.get('per_page', 10, type=int)

        # Cut-off timestamp: 90 days ago.
        ninety_days_ago = datetime.utcnow() - timedelta(days=90)

        pagination = (
            GenerationRecord.query.filter(
                GenerationRecord.user_id == user_id,
                GenerationRecord.created_at >= ninety_days_ago,
                GenerationRecord.prompt != "解读验光单"  # hide prescription-assistant records
            )
            .order_by(GenerationRecord.created_at.desc())
            .paginate(page=page, per_page=per_page, error_out=False)
        )

        # Normalise URLs so old and new storage formats look the same.
        history_list = []
        for r in pagination.items:
            # A record without stored URLs must not break the whole page.
            raw_urls = json.loads(r.image_urls or "[]")
            formatted_urls = []
            for u in raw_urls:
                if isinstance(u, str):
                    # Legacy data: the original image doubles as the thumbnail.
                    formatted_urls.append({"url": u, "thumb": u})
                else:
                    # Videos without a thumbnail get a generic play icon.
                    if u.get('type') == 'video' and not u.get('thumb'):
                        u['thumb'] = "https://img.icons8.com/flat-round/64/000000/play--v1.png"
                    formatted_urls.append(u)

            history_list.append({
                "id": r.id,
                "prompt": r.prompt,
                "model": r.model,
                "urls": formatted_urls,
                # Shift by 8 hours for display.
                "created_at": (r.created_at + timedelta(hours=8)).strftime('%b %d, %H:%M')
            })

        return jsonify({
            "history": history_list,
            "has_next": pagination.has_next,
            "total": pagination.total
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500


def _content_disposition(filename):
    """Build a safe ``Content-Disposition: attachment`` header value.

    Path components, double quotes and CR/LF are stripped. Non-ASCII names
    get an additional RFC 5987 ``filename*`` parameter, because the plain
    ``filename`` parameter can only carry ASCII reliably.
    """
    cleaned = os.path.basename(filename or "")
    for bad in ('"', '\r', '\n'):
        cleaned = cleaned.replace(bad, '')
    if not cleaned:
        cleaned = "download"

    ascii_name = cleaned.encode('ascii', 'ignore').decode('ascii') or "download"
    header = f'attachment; filename="{ascii_name}"'
    if ascii_name != cleaned:
        header += f"; filename*=UTF-8''{quote(cleaned)}"
    return header


@api_bp.route('/api/download_proxy', methods=['GET'])
@login_required
def download_proxy():
    """Proxy a remote file so the browser shows a download dialog.

    Query parameters: ``url`` (required) and ``filename`` (optional).
    """
    # NOTE(review): this fetches any URL supplied by a logged-in user;
    # consider restricting hosts to avoid requests to internal services.
    url = request.args.get('url')
    filename = request.args.get('filename', f"video-{int(time.time())}.mp4")

    if not url:
        return jsonify({"error": "缺少 URL 参数"}), 400

    try:
        # Stream the remote file instead of loading it into memory.
        upstream = requests.get(url, stream=True, timeout=60)
        try:
            upstream.raise_for_status()
        except Exception:
            upstream.close()
            raise

        headers = {
            'Content-Type': upstream.headers.get('Content-Type') or 'application/octet-stream',
            'Content-Disposition': _content_disposition(filename),
        }

        def stream_body():
            # Release the upstream connection when streaming ends or the
            # client disconnects.
            try:
                yield from upstream.iter_content(chunk_size=4096)
            finally:
                upstream.close()

        return current_app.response_class(stream_body(), headers=headers)
    except Exception as e:
        system_logger.error(f"代理下载失败: {str(e)}")
        return jsonify({"error": "下载失败"}), 500