260 lines
7.6 KiB
Python
260 lines
7.6 KiB
Python
#!/usr/bin/env python3
|
|
# @skill: qwen-image-generation
|
|
|
|
"""
|
|
Qwen Image Generation Script
|
|
Generate images using Qwen (DashScope) API
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import time
|
|
import requests
|
|
from urllib.parse import urlparse
|
|
|
|
|
|
def _str_to_bool(value: str) -> bool:
    """Parse a command-line boolean such as ``true``/``false``, ``yes``/``no`` or ``1``/``0``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling, so
            argparse reports a usage error instead of silently treating it as False.
    """
    normalized = value.strip().lower()
    if normalized in ("true", "1", "yes", "y", "on"):
        return True
    if normalized in ("false", "0", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean (true/false), got {value!r}")


def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Argument list to parse. When None, argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes api_key, prompt, size, n,
        negative_prompt, prompt_extend, image_url, output_path and api_base.
    """
    parser = argparse.ArgumentParser(
        description="Generate images using Qwen (DashScope) API",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Optional on purpose: main() falls back to the DASHSCOPE_API_KEY environment
    # variable when this flag is omitted, as the help text promises.
    parser.add_argument(
        "--api-key",
        type=str,
        default=None,
        help="DashScope API key (can also be set via DASHSCOPE_API_KEY environment variable)",
    )

    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Image generation prompt (max 800 chars)",
    )

    parser.add_argument(
        "--size",
        type=str,
        default="1024*1024",
        help="Image resolution (default: 1024*1024, options: 1344*768, 768*1344, 1184*864, 864*1184)",
    )

    parser.add_argument(
        "--n",
        type=int,
        default=1,
        choices=range(1, 7),
        help="Number of images to generate (default: 1, max: 6)",
    )

    parser.add_argument(
        "--negative-prompt",
        type=str,
        default=None,
        help="Negative prompt to avoid certain elements (max 500 chars)",
    )

    # Accepts true/false, yes/no, 1/0, on/off; anything else is a usage error.
    parser.add_argument(
        "--prompt-extend",
        type=_str_to_bool,
        default=True,
        metavar="BOOL",
        help="Enable prompt extend to enhance the prompt (default: true)",
    )

    parser.add_argument(
        "--image-url",
        type=str,
        default=None,
        help="Reference image URL for image-to-image generation",
    )

    parser.add_argument(
        "--output-path",
        type=str,
        default=None,
        help="Local path to save the generated image",
    )

    parser.add_argument(
        "--api-base",
        type=str,
        default="https://dashscope.aliyuncs.com",
        help="API base URL (default: https://dashscope.aliyuncs.com)",
    )

    return parser.parse_args(argv)
|
|
|
|
|
|
def download_image(url: str, output_path: str) -> bool:
    """Fetch ``url`` and write the response bytes to ``output_path``.

    Failures (network, HTTP status, or filesystem) are reported on stdout
    instead of being raised, so one bad image does not abort a batch.

    Returns:
        True if the file was written, False if anything went wrong.
    """
    try:
        resp = requests.get(url, timeout=30)
        resp.raise_for_status()
        with open(output_path, "wb") as fh:
            fh.write(resp.content)
        print(f" [OK] Saved: {output_path}")
    except Exception as exc:
        print(f" [FAIL] Download failed: {exc}")
        return False
    return True
|
|
|
|
|
|
_MODEL = "qwen-image-2.0-pro"


def _build_payload(args):
    """Build the JSON request body for the multimodal-generation endpoint from CLI args."""
    content = [{"text": args.prompt}]
    if args.image_url:
        # NOTE(review): the response parser in this file reads {"image": ...}
        # content items, which suggests the native endpoint may also expect the
        # reference image as {"image": url} rather than the
        # {"image_url": {"url": ...}} shape sent here — confirm against the
        # DashScope API docs.
        content.append({"image_url": {"url": args.image_url}})

    parameters = {
        "prompt_extend": args.prompt_extend,
        "size": args.size,
        "n": args.n,
    }
    # Only send negative_prompt when the user supplied one.
    if args.negative_prompt:
        parameters["negative_prompt"] = args.negative_prompt

    return {
        "model": _MODEL,
        "input": {"messages": [{"role": "user", "content": content}]},
        "parameters": parameters,
    }


def _api_error_message(body):
    """Return a readable error from a decoded response body, or None if it reports no error.

    Accepts an ``error`` member (an object with ``message``, or any other value)
    and top-level ``code``/``message`` fields. The latter is presumed to be
    DashScope's native error shape; confirm against the API docs.
    """
    if not isinstance(body, dict):
        return None
    if "error" in body:
        error = body["error"]
        if isinstance(error, dict):
            return error.get("message", "Unknown error")
        return str(error) if error else "Unknown error"
    if body.get("code"):
        return f"{body['code']}: {body.get('message', 'Unknown error')}"
    return None


def _extract_image_urls(body):
    """Collect every ``image`` value found in ``output.choices[*].message.content[*]``."""
    urls = []
    output = body.get("output") or {}
    for choice in output.get("choices") or []:
        message = choice.get("message") or {}
        for item in message.get("content") or []:
            if isinstance(item, dict) and "image" in item:
                urls.append(item["image"])
    return urls


def _resolve_output_path(requested, img_url, index, total, timestamp):
    """Return the local filename for image ``index`` (1-based) out of ``total``.

    The extension is the one in ``requested`` if present, otherwise the one in the
    image URL, falling back to ``.png``.

    - A single image is written to ``requested``, with the extension appended
      only when ``requested`` has none.
    - Several images become ``<root>_<index>_<timestamp><ext>`` so they do not
      overwrite each other.
    """
    url_ext = os.path.splitext(urlparse(img_url).path)[1]
    # Fall back to .png when the URL has no usable extension (missing or implausibly long).
    if not url_ext or len(url_ext) > 5:
        url_ext = ".png"

    # splitext only inspects the last path component, so dots in directory
    # names (e.g. "./out/img") are not mistaken for an extension. Its extension
    # keeps the leading dot, which avoids a doubled "..".
    root, user_ext = os.path.splitext(requested)
    if total > 1:
        return f"{root}_{index}_{timestamp}{user_ext or url_ext}"
    return requested if user_ext else requested + url_ext


def generate_images(args):
    """Call Qwen (DashScope) API to generate images and optionally save them locally.

    Args:
        args: Parsed CLI namespace providing api_key, api_base, prompt, size, n,
            prompt_extend, negative_prompt, image_url and output_path.

    Returns:
        True if the API returned at least one image (and, when
        ``args.output_path`` is set, at least one was saved); False otherwise.
        Errors are printed rather than raised.
    """
    # Strip a trailing slash so "--api-base https://host/" does not produce "//api".
    url = f"{args.api_base.rstrip('/')}/api/v1/services/aigc/multimodal-generation/generation"

    headers = {
        "Authorization": f"Bearer {args.api_key}",
        "Content-Type": "application/json",
    }

    payload = _build_payload(args)

    print(f"\n{'=' * 60}")
    print("Qwen Image Generation")
    print(f"{'=' * 60}")
    print(f"Model: {_MODEL}")
    print(f"Prompt: {args.prompt}")
    print(f"Size: {args.size}")
    print(f"Number: {args.n}")
    print(f"Prompt Extend: {'Enabled' if args.prompt_extend else 'Disabled'}")
    if args.negative_prompt:
        print(f"Negative Prompt: {args.negative_prompt}")
    if args.image_url:
        print(f"Reference Image: {args.image_url}")
    print(f"{'=' * 60}\n")

    try:
        print("Generating images...")
        response = requests.post(url, headers=headers, json=payload, timeout=180)

        # Decode the body before looking at the status. raise_for_status() would
        # discard the explanation, which the API presumably returns as JSON on errors.
        try:
            result = response.json()
        except ValueError:
            result = None

        if not response.ok:
            detail = _api_error_message(result) or response.text[:500] or response.reason
            print(f"\nRequest Error: HTTP {response.status_code}: {detail}")
            return False

        if not isinstance(result, dict):
            print("\nUnexpected Error: response body is not a JSON object")
            return False

        error_msg = _api_error_message(result)
        if error_msg:
            print(f"API Error: {error_msg}")
            return False

        image_urls = _extract_image_urls(result)
        if not image_urls:
            print("\nNo images were returned by the API.")
            return False

        usage = result.get("usage") or {}
        width, height = usage.get("width"), usage.get("height")
        # Show the requested size when usage omits the actual dimensions.
        dims = f"{width}x{height}" if width and height else args.size.replace("*", "x")
        request_id = result.get("request_id", "N/A")

        print(f"\nSuccessfully generated {len(image_urls)} image(s) ({dims})")
        print(f"Request ID: {request_id}\n")

        saved_count = 0
        if args.output_path:
            # One timestamp per run keeps the files of a batch grouped together.
            timestamp = int(time.time())
            for i, img_url in enumerate(image_urls, 1):
                output_path = _resolve_output_path(
                    args.output_path, img_url, i, len(image_urls), timestamp
                )
                parent = os.path.dirname(output_path)
                if parent:
                    try:
                        os.makedirs(parent, exist_ok=True)
                    except OSError as e:
                        # Report and move on so one bad path does not abort the batch.
                        print(f" [FAIL] Could not create directory {parent}: {e}")
                        continue
                if download_image(img_url, output_path):
                    saved_count += 1
        else:
            print("Image URLs:")
            for i, img_url in enumerate(image_urls, 1):
                print(f" {i}. {img_url}")

        print(f"\n{'=' * 60}")
        if args.output_path:
            print(f"Done! Successfully saved {saved_count}/{len(image_urls)} images")
        print(f"{'=' * 60}\n")

        # When saving, success means at least one file was actually written.
        return saved_count > 0 if args.output_path else True

    except requests.exceptions.RequestException as e:
        print(f"\nRequest Error: {e}")
        return False
    except Exception as e:
        # Top-level boundary: report anything unexpected (e.g. a malformed
        # response shape) instead of crashing the CLI.
        print(f"\nUnexpected Error: {e}")
        return False
|
|
|
|
|
|
def main() -> int:
    """Run the command-line tool.

    Resolves the API key from ``--api-key`` or, failing that, the
    ``DASHSCOPE_API_KEY`` environment variable, then generates images.

    Returns:
        Process exit status: 0 on success, 1 if the API key is missing or
        image generation failed.
    """
    args = parse_args()

    # Fall back to the environment when --api-key was not given.
    if not args.api_key:
        args.api_key = os.environ.get("DASHSCOPE_API_KEY", "")

    # Still no key: report it and signal failure to the shell.
    if not args.api_key:
        print("Error: API key is required (--api-key or DASHSCOPE_API_KEY)")
        return 1

    return 0 if generate_images(args) else 1


if __name__ == "__main__":
    # Propagate main()'s status so scripts and CI can detect failures.
    raise SystemExit(main())
|