huabu/node_engine.py

265 lines
11 KiB
Python
Raw Permalink Normal View History

2026-02-07 00:17:23 +08:00
import os
import json
import time
import requests
from config import Config
class NodeEngine:
    """Manage node definition files and execute workflow nodes.

    Node definitions are JSON files stored in ``config_dir``. The built-in
    definitions are (re)written every time an engine is constructed so they
    always reflect the current capabilities. ``execute_node`` dispatches on
    the node id and returns a result dict tagged by its ``"type"`` key.
    """

    # Default endpoint of the system dictionary service. Used both as the
    # hidden ``api_url`` default in the node definition and as the executor
    # fallback, so the two values cannot drift apart.
    DEFAULT_DICT_API_URL = "https://nas.4x4g.com:10011/api/common/sys/dict"
    # Obsolete definition files deleted on startup so they stop appearing.
    LEGACY_CONFIG_FILES = ("dict_node.json", "sdxl.json")
    # Async task polling: up to MAX_POLL_ATTEMPTS checks, sleeping
    # POLL_INTERVAL_SECONDS before each one.
    POLL_INTERVAL_SECONDS = 2
    MAX_POLL_ATTEMPTS = 30
    # Terminal task states recognised in poll responses; any other value
    # (including a missing status) means "keep polling".
    SUCCESS_STATES = ("succeeded", "SUCCESS")
    FAILURE_STATES = ("failed", "FAILURE")

    def __init__(self, config_dir='configs/nodes'):
        """Create the engine and make sure the built-in node files exist.

        Args:
            config_dir: Directory holding the node definition ``*.json`` files.
        """
        self.config_dir = config_dir
        self.ensure_config_dir()

    def ensure_config_dir(self):
        """Create ``config_dir`` if needed and rewrite the built-in nodes."""
        os.makedirs(self.config_dir, exist_ok=True)
        # Always update the default nodes to reflect new capabilities.
        self.create_default_node()

    def create_default_node(self):
        """Write the built-in node definitions and delete legacy files.

        Existing files with the same names are overwritten.
        """
        # 1. Nano-banana Generator
        nano_config = {
            "id": "nano_banana",
            "name": "Nano-banana 图片生成",
            "type": "generator",
            "inputs": [
                {"name": "prompt", "label": "提示词 (Prompt)", "ui_widget": "text_area", "data_type": "text"},
                {
                    "name": "aspect_ratio",
                    "label": "图片比例",
                    "ui_widget": "select",
                    "options": ["1:1", "4:3", "3:4", "16:9", "9:16", "3:2", "2:3", "21:9"],
                    "data_type": "string"
                },
                {
                    "name": "model",
                    "label": "模型版本",
                    "ui_widget": "select",
                    "options": ["nano-banana", "nano-banana-hd"],
                    "data_type": "string"
                }
            ],
            "outputs": [
                {"name": "image", "label": "生成图像", "data_type": "image"}
            ]
        }
        # 2. Image Preview Node
        preview_config = {
            "id": "image_preview",
            "name": "图片预览",
            "type": "preview",
            "inputs": [
                {"name": "image", "label": "输入预览", "ui_widget": "hidden", "data_type": "image"}
            ],
            "outputs": []
        }
        # 3. Image Upload Node
        upload_config = {
            "id": "image_upload",
            "name": "本地上传",
            "type": "input",
            "inputs": [
                {"name": "file", "label": "选择文件", "ui_widget": "file_upload", "data_type": "file"}
            ],
            "outputs": [
                {"name": "image", "label": "输出图像", "data_type": "image"}
            ]
        }
        # 4. Text Input Node
        text_input_config = {
            "id": "text_input",
            "name": "文本输入",
            "type": "input",
            "inputs": [
                {"name": "text", "label": "输入文本", "ui_widget": "text_area", "data_type": "text"}
            ],
            "outputs": [
                {"name": "text", "label": "输出文本", "data_type": "text"}
            ]
        }
        # 5. System Dictionary API Node
        dict_api_config = {
            "id": "sys_dict_api",
            "name": "系统字典接口",
            "type": "input",
            "inputs": [
                {"name": "code", "label": "字典编码 (Code)", "ui_widget": "text_input", "data_type": "string"},
                {"name": "api_url", "label": "接口地址", "ui_widget": "hidden", "data_type": "string", "default": self.DEFAULT_DICT_API_URL}
            ],
            "outputs": [
                {"name": "options", "label": "字典数据 (Options)", "data_type": "dict"}
            ]
        }
        default_nodes = {
            'nano_banana.json': nano_config,
            'image_preview.json': preview_config,
            'image_upload.json': upload_config,
            'text_input.json': text_input_config,
            'sys_dict_api.json': dict_api_config,
        }
        for filename, node_config in default_nodes.items():
            with open(os.path.join(self.config_dir, filename), 'w', encoding='utf-8') as f:
                json.dump(node_config, f, indent=4, ensure_ascii=False)
        # Remove superseded definitions (old dictionary node, old sdxl
        # generator) so they do not cause confusion in the node list.
        for legacy_name in self.LEGACY_CONFIG_FILES:
            legacy_path = os.path.join(self.config_dir, legacy_name)
            if os.path.exists(legacy_path):
                os.remove(legacy_path)

    def get_all_node_configs(self):
        """Load every ``*.json`` node definition in ``config_dir``.

        Files are read in sorted filename order so the resulting list is
        stable across platforms (``os.listdir`` order is arbitrary). A file
        that cannot be read or parsed is reported and skipped rather than
        aborting the whole listing.

        Returns:
            list: The parsed node definition objects.
        """
        configs = []
        for filename in sorted(os.listdir(self.config_dir)):
            if not filename.endswith('.json'):
                continue
            try:
                with open(os.path.join(self.config_dir, filename), 'r', encoding='utf-8') as f:
                    configs.append(json.load(f))
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                print(f"Error loading {filename}: {e}")
        return configs

    def execute_node(self, node_id, data):
        """Run a node and return its result.

        Dispatch is by substring of ``node_id``: ids containing ``"preview"``
        or ``"upload"`` are resolved locally, ids containing ``"sys_dict"``
        query the dictionary API, and every other id is treated as a
        Nano-banana image generation request.

        Args:
            node_id: Identifier of the node to run.
            data: Input values for the node, keyed by input name.

        Returns:
            dict: One of ``{"type": "image", "url": ..., "time": ...}``,
            ``{"type": "text", "content": ..., "time": ...}`` or
            ``{"type": "error", "error": ...}``.
        """
        if "preview" in node_id:
            # Preview node just shows the input image.
            img_url = data.get('uploaded_url') or data.get('image')
            if not img_url:
                return {"type": "error", "error": "没有可预览的图像数据"}
            return {"type": "image", "url": img_url, "time": 0.1}
        if "upload" in node_id:
            # Upload node returns the uploaded URL.
            img_url = data.get('uploaded_url')
            if not img_url:
                return {"type": "error", "error": "请先上传图片"}
            return {"type": "image", "url": img_url, "time": 0.1}
        if "sys_dict" in node_id:
            return self._execute_sys_dict(data)
        # NOTE(review): every other id -- including "text_input" -- falls
        # through to image generation; confirm callers never execute
        # input-only nodes here.
        return self._generate_image(data)

    def _execute_sys_dict(self, data):
        """Fetch dictionary options and return them as pretty-printed JSON text."""
        code = data.get('code', 'aspect_ratio')
        api_url = data.get('api_url', self.DEFAULT_DICT_API_URL)
        try:
            response = requests.get(api_url, params={"code": code}, timeout=10)
            response.raise_for_status()
            res_data = response.json()
            # Return the options list as the result.
            options = res_data.get('data', {}).get('options', [])
            return {
                "type": "text",
                "content": json.dumps(options, indent=4, ensure_ascii=False),
                "time": 0.2,
            }
        except Exception as e:
            return {"type": "error", "error": f"字典获取失败: {str(e)}"}

    def _generate_image(self, data):
        """Submit a Nano-banana generation request and wait for the image.

        The full API URL is prefixed with ``Config.PROXY``
        (``https://proxy.com/https://api.com/...``) and submitted with
        ``async=true``. If the response carries a ``task_id`` the task is
        polled; otherwise the response is treated as a synchronous result.
        Every failure after submission is returned as an error dict.
        """
        prompt = data.get('prompt', 'A beautiful landscape')
        aspect_ratio = data.get('aspect_ratio', '1:1')
        model = data.get('model', 'nano-banana')
        api_base = f"{Config.PROXY}{Config.BASE_URL.rstrip('/')}"
        url = f"{api_base}/v1/images/generations?async=true"
        headers = {
            "Authorization": f"Bearer {Config.API_KEY}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": model,
            "prompt": prompt,
            "aspect_ratio": aspect_ratio,
            "response_format": "url",
        }
        start_time = time.time()
        try:
            # 1. Submit the async task.
            response = requests.post(url, headers=headers, json=payload, timeout=30)
            response.raise_for_status()
            res_data = response.json()
            task_id = res_data.get('task_id')
            if not task_id:
                # Fallback to a synchronous answer (backward compatibility).
                image_url = self._extract_result_url(res_data)
                if not image_url:
                    raise RuntimeError("API response contained neither a task_id nor an image URL")
                return {"type": "image", "url": image_url, "time": self._elapsed(start_time)}
            print(f"Task submitted, ID: {task_id}. Polling for results...")
            # 2. Poll for the result.
            return self._poll_task(f"{api_base}/v1/images/tasks/{task_id}", headers, start_time)
        except Exception as e:
            print(f"API Error: {e}")
            return {
                "type": "error",
                "error": str(e),
            }

    def _poll_task(self, poll_url, headers, start_time):
        """Poll an async image task until it succeeds, fails or times out.

        Transient errors (network, SSL, undecodable body) are logged and
        retried. A task reported as failed ends polling immediately.

        Returns:
            dict: ``{"type": "image", "url": ..., "time": ...}`` on success.

        Raises:
            RuntimeError: If the task fails or no result arrives in time.
        """
        print(f"Polling URL: {poll_url}")
        # GET request doesn't need Content-Type.
        poll_headers = headers.copy()
        poll_headers.pop("Content-Type", None)
        # One session for all polls to avoid connection pool issues.
        with requests.Session() as session:
            for attempt in range(1, self.MAX_POLL_ATTEMPTS + 1):
                time.sleep(self.POLL_INTERVAL_SECONDS)
                task_data = self._fetch_task(session, poll_url, poll_headers)
                if task_data is None:
                    continue
                status, inner_data = self._task_status(task_data)
                print(f"Poll {attempt}/{self.MAX_POLL_ATTEMPTS}: Status={status}")
                if status in self.SUCCESS_STATES:
                    image_url = self._extract_result_url(inner_data)
                    if image_url:
                        return {"type": "image", "url": image_url, "time": self._elapsed(start_time)}
                    # Succeeded but no URL found yet: keep polling.
                elif status in self.FAILURE_STATES:
                    # Raised outside the per-poll error handling so a failed
                    # task stops polling and its reason reaches the caller.
                    reason = inner_data.get('fail_reason') or inner_data.get('error')
                    raise RuntimeError(f"Task failed: {reason}")
                # Any other status (pending, processing, None, ...) means wait.
        raise RuntimeError(
            f"Task timed out after {self.MAX_POLL_ATTEMPTS * self.POLL_INTERVAL_SECONDS} seconds"
        )

    @staticmethod
    def _fetch_task(session, poll_url, poll_headers):
        """Fetch one task status document, or return None on a transient error."""
        try:
            # NOTE(review): TLS verification is disabled for polling only (the
            # submit request verifies) while the bearer token is still sent;
            # confirm this is intended.
            poll_res = session.get(poll_url, headers=poll_headers, timeout=20, verify=False)
            poll_res.raise_for_status()
            task_data = poll_res.json()
        except requests.exceptions.SSLError:
            print("SSL Error during polling, retrying...")
            return None
        except (requests.exceptions.RequestException, ValueError) as poll_e:
            # Don't give up on ephemeral network errors or a garbled body.
            print(f"Polling error: {poll_e}")
            return None
        print(f"DEBUG Task Data: {task_data}")
        if not isinstance(task_data, dict):
            print(f"Polling error: unexpected response {task_data!r}")
            return None
        return task_data

    @staticmethod
    def _task_status(task_data):
        """Return ``(status, inner_data)`` for a poll response.

        Some proxies wrap the task in ``{"data": {"status": ...}}``; in that
        case the nested dict is used. The status is read from ``status`` or,
        failing that, ``state`` and may be None.
        """
        inner_data = task_data
        nested = task_data.get('data')
        if isinstance(nested, dict) and 'status' in nested:
            inner_data = nested
        return inner_data.get('status') or inner_data.get('state'), inner_data

    @staticmethod
    def _extract_result_url(inner_data):
        """Return the image URL from a task/response dict, or None if absent.

        Accepted layouts (``"result"`` may replace the outer ``"data"``):
        ``{"data": {"data": [{"url": ...}]}}``, ``{"data": {"url": ...}}`` and
        ``{"data": [{"url": ...}]}``.
        """
        payload = inner_data.get('data') or inner_data.get('result')
        items = payload.get('data') if isinstance(payload, dict) else payload
        if isinstance(items, list) and items and isinstance(items[0], dict) and items[0].get('url'):
            return items[0]['url']
        if isinstance(payload, dict) and payload.get('url'):
            return payload['url']
        return None

    @staticmethod
    def _elapsed(start_time):
        """Seconds since ``start_time``, rounded to two decimals."""
        return round(time.time() - start_time, 2)