Changed the way the object directory is read

commit 4a4b37cba7
parent 67722a5c72
@@ -8,128 +8,229 @@ import time
 from datetime import datetime
 import logging

-# Add project root to the Python path
-project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-sys.path.insert(0, project_root)
+# --- Setup Project Root Path ---
+PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+if PROJECT_ROOT not in sys.path:
+    sys.path.append(PROJECT_ROOT)
+# --- End Setup ---

-from core.ai_agent import AI_Agent
-from utils.prompt_manager import PromptManager
-from utils.tweet_generator import (
-    run_topic_generation_pipeline,
-    generate_content_for_topic,
-    generate_posters_for_topic
-)
+# --- Imports from the project ---
+try:
+    from utils.tweet_generator import run_topic_generation_pipeline, generate_content_for_topic, generate_posters_for_topic
+    from core.topic_parser import TopicParser
+    from utils.prompt_manager import PromptManager # Needed for content generation
+    from core.ai_agent import AI_Agent # Needed for content generation
+    # from utils.tweet_generator import tweetTopicRecord # No longer needed directly
+    from utils.output_handler import FileSystemOutputHandler, OutputHandler # Import handlers
+except ImportError as e:
+    logging.critical(f"ImportError: {e}. Ensure all core/utils modules are available and '{PROJECT_ROOT}' is in sys.path.")
+    sys.exit(1)
+# --- End Imports ---

-def load_config(config_path="/root/autodl-tmp/TravelContentCreator/poster_gen_config.json"):
-    """Loads configuration relative to the script."""
-    if not os.path.exists(config_path):
-        logging.error(f"Error: Config file '{config_path}' not found.")
-        sys.exit(1)
+def load_config(config_path):
+    """Loads configuration from a JSON file."""
     try:
         with open(config_path, 'r', encoding='utf-8') as f:
             config = json.load(f)
-        logging.info("Configuration loaded successfully.")
+        logging.info(f"Config loaded successfully from {config_path}")
         return config
+    except FileNotFoundError:
+        logging.error(f"Error: Configuration file not found at {config_path}")
+        return None
+    except json.JSONDecodeError:
+        logging.error(f"Error: Could not decode JSON from {config_path}")
+        return None
     except Exception as e:
-        logging.error(f"Error loading configuration: {e}")
-        sys.exit(1)
+        logging.exception(f"An unexpected error occurred loading config {config_path}:")
+        return None

 def main_test():
+    # --- Basic Logging Setup ---
     logging.basicConfig(
         level=logging.INFO,
         format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
         datefmt='%Y-%m-%d %H:%M:%S'
     )
+    # --- End Logging Setup ---

-    logging.info("--- Starting Pipeline Step Test ---")
-    config = load_config()
+    logging.info("--- Starting Pipeline Steps Test ---")

-    # --- Override config for faster testing ---
-    config['num'] = 1 # Generate only 1 topic
-    config['variants'] = 1 # Generate only 1 content/poster variant
-    logging.info(f"Config overridden for testing: num={config['num']}, variants={config['variants']}")
+    # 1. Load Configuration
+    config_path = os.path.join(PROJECT_ROOT, "poster_gen_config.json") # Use project root path
+    config = load_config(config_path)
+    if config is None:
+        logging.critical("Failed to load configuration. Exiting test.")
+        sys.exit(1)

-    run_id = None
-    tweet_topic_record = None
-    ai_agent_content = None # Separate agent instance for content/poster
+    # --- Initialize Output Handler ---
+    output_handler: OutputHandler = FileSystemOutputHandler(config.get("output_dir", "result"))
+    logging.info(f"Using Output Handler: {output_handler.__class__.__name__}")
+    # --- End Output Handler Init ---
+
+    # 2. Define a unique Run ID for this test run
+    test_run_id = f"test_pipeline_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
+    logging.info(f"Using Run ID: {test_run_id}")
+
+    # --- Step 1: Topic Generation ---
+    logging.info("\n--- Testing Step 1: Topic Generation ---")
+    step1_start = time.time()
+    # run_topic_generation_pipeline still takes config internally
+    run_id_step1, topics_list, system_prompt, user_prompt = run_topic_generation_pipeline(config, run_id=test_run_id)
+    step1_end = time.time()
+
+    if run_id_step1 is None or topics_list is None:
+        logging.error("Topic generation (Step 1) failed. Cannot proceed.")
+        sys.exit(1)
+    if run_id_step1 != test_run_id:
+        logging.warning(f"run_id returned from step 1 ({run_id_step1}) differs from the one provided ({test_run_id}). Using the returned one.")
+        test_run_id = run_id_step1 # Use the ID actually used by the pipeline
+
+    logging.info(f"Step 1 finished in {step1_end - step1_start:.2f} seconds.")
+    logging.info(f"Generated {len(topics_list)} topics for Run ID: {test_run_id}")
+    # Use output handler to save topic results (mimics main.py)
+    output_handler.handle_topic_results(test_run_id, topics_list, system_prompt, user_prompt)
+
+    # --- Step 2: Content and Poster Generation ---
+    logging.info("\n--- Testing Step 2: Content and Poster Generation ---")
+    step2_start = time.time()
+
+    # Initialize resources needed for step 2, extracting params from config
+    prompt_manager = None
+    ai_agent = None
+    step2_success_flag = False
     try:
-        # --- Step 1: Test Topic Generation ---
-        logging.info("\n--- Testing Topic Generation ---")
-        run_id, tweet_topic_record = run_topic_generation_pipeline(config) # run_id generated inside if not passed
-
-        if not run_id or not tweet_topic_record or not tweet_topic_record.topics_list:
-            logging.info("Topic generation failed or produced no topics. Exiting test.")
-            return
-
-        logging.info(f"Topic generation successful. Run ID: {run_id}")
-        logging.info(f"Generated {len(tweet_topic_record.topics_list)} topic(s).")
-        test_topic = tweet_topic_record.topics_list[0] # Get the first topic for testing
-        logging.info("Test Topic Data:", json.dumps(test_topic, ensure_ascii=False, indent=2))
-
-        # --- Step 2: Test Content Generation (for the first topic) ---
-        logging.info("\n--- Testing Content Generation ---")
-        # Initialize resources needed for content generation
-        prompt_manager = PromptManager(config)
-        logging.info("Initializing AI Agent for content...")
+        # --- Create PromptManager ---
+        prompt_manager = PromptManager(
+            topic_system_prompt_path=config.get("topic_system_prompt"),
+            topic_user_prompt_path=config.get("topic_user_prompt"),
+            content_system_prompt_path=config.get("content_system_prompt"),
+            prompts_dir=config.get("prompts_dir"),
+            resource_dir_config=config.get("resource_dir", []),
+            topic_gen_num=config.get("num", 1),
+            topic_gen_date=config.get("date", "")
+        )
+        logging.info("PromptManager instance created for Step 2 test.")
+
+        # --- Create AI Agent ---
+        ai_api_url = config.get("api_url")
+        ai_model = config.get("model")
+        ai_api_key = config.get("api_key")
         request_timeout = config.get("request_timeout", 30)
         max_retries = config.get("max_retries", 3)
-        ai_agent_content = AI_Agent(
-            config["api_url"],
-            config["model"],
-            config["api_key"],
+        if not all([ai_api_url, ai_model, ai_api_key]):
+            raise ValueError("Missing required AI configuration (api_url, model, api_key)")
+        logging.info("Initializing AI Agent for content generation test...")
+        ai_agent = AI_Agent(
+            base_url=ai_api_url,
+            model_name=ai_model,
+            api=ai_api_key,
             timeout=request_timeout,
             max_retries=max_retries
         )

-        base_output_dir = config["output_dir"]
-        topic_index = 1 # Testing the first topic (1-based index)
-
-        tweet_content_list = generate_content_for_topic(
-            ai_agent_content, prompt_manager, config, test_topic,
-            base_output_dir, run_id, topic_index
-        )
-
-        if not tweet_content_list:
-            logging.info("Content generation failed or produced no content. Exiting test.")
-            return
-
-        logging.info(f"Content generation successful. Generated {len(tweet_content_list)} variant(s).")
-        logging.info("Generated Content Data (first variant):", json.dumps(tweet_content_list[0], ensure_ascii=False, indent=2))
-
-        # --- Step 3: Test Poster Generation (for the first topic/content) ---
-        logging.info("\n--- Testing Poster Generation ---")
-
-        # Poster generation uses its own internal ContentGenerator and PosterGenerator instances
-        # We just need to call the function
-        success = generate_posters_for_topic(
-            config,
-            test_topic,
-            tweet_content_list, # Pass the list generated above
-            base_output_dir,
-            run_id,
-            topic_index
-        )
-
-        if success:
-            logging.info("Poster generation function executed (check output directory for results).")
-        else:
-            logging.info("Poster generation function reported failure or skipped execution.")
+        total_topics = len(topics_list)
+        logging.info(f"Processing {total_topics} topics for content/posters...")
+        for i, topic_item in enumerate(topics_list):
+            topic_index = topic_item.get('index', i + 1)
+            logging.info(f"--- Processing Topic {topic_index}/{total_topics} ---")
+
+            # --- Generate Content ---
+            content_variants = config.get("variants", 1)
+            content_temp = config.get("content_temperature", 0.3)
+            content_top_p = config.get("content_top_p", 0.4)
+            content_presence_penalty = config.get("content_presence_penalty", 1.5)
+
+            content_success = generate_content_for_topic(
+                ai_agent=ai_agent,
+                prompt_manager=prompt_manager,
+                topic_item=topic_item,
+                run_id=test_run_id,
+                topic_index=topic_index,
+                output_handler=output_handler,
+                variants=content_variants,
+                temperature=content_temp,
+                top_p=content_top_p,
+                presence_penalty=content_presence_penalty
+            )
+
+            if content_success:
+                logging.info(f"Content generation successful for Topic {topic_index}.")
+                # --- Generate Posters ---
+                poster_variants = config.get("variants", 1)
+                poster_assets_dir = config.get("poster_assets_base_dir")
+                img_base_dir = config.get("image_base_dir")
+                mod_img_subdir = config.get("modify_image_subdir", "modify")
+                res_dir_config = config.get("resource_dir", [])
+                poster_size = tuple(config.get("poster_target_size", [900, 1200]))
+                txt_possibility = config.get("text_possibility", 0.3)
+                collage_subdir = config.get("output_collage_subdir", "collage_img")
+                poster_subdir = config.get("output_poster_subdir", "poster")
+                poster_filename = config.get("output_poster_filename", "poster.jpg")
+                cam_img_subdir = config.get("camera_image_subdir", "相机")
+
+                if not poster_assets_dir or not img_base_dir:
+                    logging.error(f"Missing critical paths for poster generation. Skipping posters for topic {topic_index}.")
+                    continue
+
+                posters_attempted = generate_posters_for_topic(
+                    topic_item=topic_item,
+                    output_dir=config.get("output_dir", "result"),
+                    run_id=test_run_id,
+                    topic_index=topic_index,
+                    output_handler=output_handler,
+                    variants=poster_variants,
+                    poster_assets_base_dir=poster_assets_dir,
+                    image_base_dir=img_base_dir,
+                    modify_image_subdir=mod_img_subdir,
+                    resource_dir_config=res_dir_config,
+                    poster_target_size=poster_size,
+                    text_possibility=txt_possibility,
+                    output_collage_subdir=collage_subdir,
+                    output_poster_subdir=poster_subdir,
+                    output_poster_filename=poster_filename,
+                    camera_image_subdir=cam_img_subdir
+                )
+                if posters_attempted:
+                    logging.info(f"Poster generation process completed for Topic {topic_index}.")
+                    step2_success_flag = True
+                else:
+                    logging.warning(f"Poster generation skipped/failed for Topic {topic_index}.")
+            else:
+                logging.warning(f"Content generation failed for Topic {topic_index}. Skipping posters.")
+            logging.info(f"--- Finished Topic {topic_index} ---")

+    except ValueError as e:
+        logging.error(f"Configuration error during Step 2 setup: {e}")
+        step2_success_flag = False
     except Exception as e:
-        logging.info(f"\n--- An error occurred during testing ---")
-        logging.error(f"Error: {e}")
+        logging.exception("An error occurred during Step 2 processing:")
+        step2_success_flag = False # Ensure flag is false on error
     finally:
-        # Clean up the content generation AI agent if it was created
-        if ai_agent_content:
-            logging.info("\nClosing content generation AI Agent...")
-            ai_agent_content.close()
-    logging.info("\n--- Test Finished ---")
+        if ai_agent:
+            logging.info("Closing AI Agent for content generation test...")
+            ai_agent.close()
+    # --- End Simulated Step 2 Logic ---
+
+    step2_end = time.time()
+    if step2_success_flag:
+        logging.info(f"Step 2 finished in {step2_end - step2_start:.2f} seconds.")
+    else:
+        logging.warning(f"Step 2 finished in {step2_end - step2_start:.2f} seconds, but encountered errors or generated no output.")
+
+    # --- Finalize Output ---
+    if test_run_id:
+        output_handler.finalize(test_run_id)
+    # --- End Finalize ---
+
+    # --- Test Summary ---
+    logging.info("\n--- Pipeline Steps Test Summary ---")
+    logging.info(f"Run ID: {test_run_id}")
+    output_location = os.path.join(config["output_dir"], test_run_id)
+    logging.info(f"Check output files in: {output_location}")
+    if os.path.exists(output_location):
+        logging.info("Output directory exists.")
+    else:
+        logging.warning("Output directory NOT found.")

 if __name__ == "__main__":
     main_test()

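The updated test above drives output through an OutputHandler but only calls two of its methods directly, handle_topic_results after Step 1 and finalize at the end (the Step 2 helpers receive the same output_handler and may rely on more of the interface). As a rough sketch, a duck-typed stand-in along these lines could replace FileSystemOutputHandler for dry runs; the class name and its in-memory storage are illustrative and are not part of utils/output_handler.py:

import logging


class InMemoryOutputHandler:
    """Hypothetical stand-in that captures topic results in memory instead of on disk."""

    def __init__(self):
        # run_id -> saved topic data
        self.runs = {}

    def handle_topic_results(self, run_id, topics_list, system_prompt, user_prompt):
        # Keep the generated topics together with the prompts that produced them.
        self.runs[run_id] = {
            "topics": topics_list,
            "system_prompt": system_prompt,
            "user_prompt": user_prompt,
        }

    def finalize(self, run_id):
        topics = self.runs.get(run_id, {}).get("topics", [])
        logging.info("Run %s captured %d topic(s)", run_id, len(topics))
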
@@ -7,101 +7,119 @@ import json
 import time
 import logging

-# Add project root to the Python path to allow importing modules from core and utils
-project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-sys.path.insert(0, project_root)
+# Determine the project root directory (assuming examples/ is one level down)
+PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+if PROJECT_ROOT not in sys.path:
+    sys.path.append(PROJECT_ROOT)

-from core.ai_agent import AI_Agent
+# Now import from core
+try:
+    from core.ai_agent import AI_Agent
+except ImportError as e:
+    logging.critical(f"Failed to import AI_Agent. Ensure '{PROJECT_ROOT}' is in sys.path and core/ai_agent.py exists. Error: {e}")
+    sys.exit(1)

-def load_config(config_path="../poster_gen_config.json"):
-    """Loads configuration from a JSON file relative to this script."""
-    if not os.path.exists(config_path):
-        logging.error(f"Error: Configuration file '{config_path}' not found.")
-        logging.error("Make sure you have copied 'example_config.json' to 'poster_gen_config.json' in the project root.")
-        sys.exit(1)
+def load_config(config_path):
+    """Loads configuration from a JSON file."""
     try:
         with open(config_path, 'r', encoding='utf-8') as f:
             config = json.load(f)
-        # Basic validation can be added here if needed
-        logging.info(f"Configuration loaded successfully from {config_path}")
+        logging.info(f"Config loaded successfully from {config_path}")
         return config
+    except FileNotFoundError:
+        logging.error(f"Error: Configuration file not found at {config_path}")
+        return None
+    except json.JSONDecodeError:
+        logging.error(f"Error: Could not decode JSON from {config_path}")
+        return None
     except Exception as e:
-        logging.error(f"Error loading configuration from '{config_path}': {e}")
-        sys.exit(1)
+        logging.exception(f"An unexpected error occurred loading config {config_path}:")
+        return None

 def main():
+    # --- Basic Logging Setup ---
     logging.basicConfig(
         level=logging.INFO,
         format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
         datefmt='%Y-%m-%d %H:%M:%S'
     )
+    # --- End Logging Setup ---

-    logging.info("--- Testing AI Agent Streaming ---")
+    logging.info("Starting AI Agent Stream Test...")

-    # 1. Load configuration
-    config = load_config()
-
-    # 2. Define example prompts (replace with your desired test prompts)
-    test_system_prompt = "You are a helpful assistant. Respond concisely."
-    test_user_prompt = "Tell me a short story about a traveling robot."
-
-    # You can optionally specify a folder with reference files
-    test_file_folder = None # Or e.g., "../resource/Object"
-
-    # Get generation parameters from config or use defaults
-    temperature = config.get("content_temperature", 0.7) # Using content params as example
-    top_p = config.get("content_top_p", 0.9)
-    presence_penalty = config.get("content_presence_penalty", 1.0)
-
-    # 3. Initialize AI Agent
+    # Load configuration (adjust path relative to this script)
+    config_path = os.path.join(PROJECT_ROOT, "poster_gen_config.json")
+    config = load_config(config_path)
+    if config is None:
+        logging.critical("Failed to load configuration. Exiting test.")
+        sys.exit(1)
+
+    # Example Prompts
+    system_prompt = "你是一个乐于助人的AI助手,擅长写短篇故事。"
+    user_prompt = "请写一个关于旅行机器人的短篇故事,它在一个充满异国情调的星球上发现了新的生命形式。"
+
     ai_agent = None
     try:
+        # --- Extract AI Agent parameters from config ---
+        ai_api_url = config.get("api_url")
+        ai_model = config.get("model")
+        ai_api_key = config.get("api_key")
         request_timeout = config.get("request_timeout", 30)
         max_retries = config.get("max_retries", 3)
+
+        # Check for required AI params
+        if not all([ai_api_url, ai_model, ai_api_key]):
+            logging.critical("Missing required AI configuration (api_url, model, api_key) in config. Exiting test.")
+            sys.exit(1)
+        # --- End Extract AI Agent params ---
+
+        logging.info("Initializing AI Agent for stream test...")
+        # Initialize AI_Agent using extracted parameters
         ai_agent = AI_Agent(
-            config["api_url"],
-            config["model"],
-            config["api_key"],
+            api_url=ai_api_url, # Use extracted var
+            model=ai_model, # Use extracted var
+            api_key=ai_api_key, # Use extracted var
             timeout=request_timeout,
             max_retries=max_retries
         )
-        logging.info("AI Agent initialized.")

-        # 4. Call work_stream and process the generator
-        logging.info("\n--- Starting stream generation ---")
+        # Example call to work_stream
+        logging.info("Calling ai_agent.work_stream...")
+        # Extract generation parameters from config
+        temperature = config.get("content_temperature", 0.7) # Use a relevant temperature setting
+        top_p = config.get("content_top_p", 0.9)
+        presence_penalty = config.get("content_presence_penalty", 0.0)

         start_time = time.time()

         stream_generator = ai_agent.work_stream(
-            test_system_prompt,
-            test_user_prompt,
-            test_file_folder,
-            temperature,
-            top_p,
-            presence_penalty
+            system_prompt=system_prompt,
+            user_prompt=user_prompt,
+            info_directory=None, # No extra context folder for this test
+            temperature=temperature,
+            top_p=top_p,
+            presence_penalty=presence_penalty
         )

-        full_response_streamed = ""
-        try:
-            for chunk in stream_generator:
-                print(chunk, end="", flush=True) # Print each chunk as it arrives
-                full_response_streamed += chunk
-        except Exception as e:
-            logging.error(f"\nError while iterating through stream generator: {e}")
+        # Process the stream
+        logging.info("Processing stream response:")
+        full_response = ""
+        for chunk in stream_generator:
+            print(chunk, end="", flush=True) # Keep print for stream output
+            full_response += chunk

         end_time = time.time()
-        logging.info(f"\n--- Stream finished in {end_time - start_time:.2f} seconds ---")
-        # print(f"Full response received via stream:\n{full_response_streamed}") # Optionally print the assembled response
+        logging.info(f"\n--- Stream Finished ---")
+        logging.info(f"Total time: {end_time - start_time:.2f} seconds")
+        logging.info(f"Total characters received: {len(full_response)}")

+    except KeyError as e:
+        logging.error(f"Configuration error: Missing key '{e}'. Please check '{config_path}'.")
     except Exception as e:
-        logging.error(f"\nAn error occurred: {e}")
-        import traceback
-        traceback.print_exc()
+        logging.exception("An error occurred during the stream test:")
     finally:
-        # 5. Close the agent
+        # Ensure the agent is closed
         if ai_agent:
-            logging.info("\nClosing AI Agent...")
+            logging.info("Closing AI Agent...")
             ai_agent.close()
             logging.info("AI Agent closed.")

main.py (18 changed lines)
@@ -71,7 +71,23 @@ def generate_content_and_posters_step(config, run_id, topics_list, output_handle
     logging.info(f"Processing {len(topics_list)} topics...")

     success_flag = False
-    prompt_manager = PromptManager(config)
+    # --- Create the PromptManager instance (pass explicit parameters) ---
+    try:
+        prompt_manager = PromptManager(
+            topic_system_prompt_path=config.get("topic_system_prompt"),
+            topic_user_prompt_path=config.get("topic_user_prompt"),
+            content_system_prompt_path=config.get("content_system_prompt"),
+            prompts_dir=config.get("prompts_dir"),
+            resource_dir_config=config.get("resource_dir", []),
+            topic_gen_num=config.get("num", 1), # Topic gen num/date used by topic prompts
+            topic_gen_date=config.get("date", "")
+        )
+        logging.info("PromptManager instance created for Step 2.")
+    except KeyError as e:
+        logging.error(f"Configuration error creating PromptManager: Missing key '{e}'. Cannot proceed with content generation.")
+        return False
+    # --- End PromptManager creation ---

     ai_agent = None
     try:
         # --- Initialize AI Agent for Content Generation ---

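The resource_dir_config passed to PromptManager here is the same structure that the reworked Object handling in PromptManager (next hunk) iterates over. A minimal sketch of the shape it appears to expect, keeping only the type and file_path keys that the diff reads; the directory layout and file names are placeholders:

# Illustrative only: "type" and "file_path" are the keys read by the new Object
# handling; the paths below are placeholders, not entries from the real config.
resource_dir = [
    {
        "type": "Object",
        "file_path": [
            "resource/Object/景点信息-尚书第建筑群.txt",
            "resource/Object/景点信息-XXX.txt",
        ],
    },
]
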
Binary file not shown.
@@ -137,36 +137,60 @@ class PromptManager:
         except Exception as e:
             logging.exception("Error processing Demand description:")

-        # 2. Add Object info (based on topic_item['object'])
+        # 2. Object info - first list all available files, then inject the content of the matched file
         try:
-            object_name_base = topic_item['object'] # This might be '景点信息-XXX.txt'
-            object_file_path = None
-            # Find the full path for the object file from config
+            object_name_from_topic = topic_item.get('object') # e.g., "尚书第建筑群"
+            object_file_basenames = []
+            matched_object_file_path = None
+            matched_object_basename = None
+
+            # Walk the resource dirs to find Object files
             for dir_info in resource_dir_config:
                 if dir_info.get("type") == "Object":
                     for file_path in dir_info.get("file_path", []):
-                        # Match basename, assuming topic_item['object'] is the basename
-                        # if os.path.basename(file_path) == object_name_base:
-                        # Use containment check instead of exact match
-                        if object_name_base in os.path.basename(file_path):
-                            object_file_path = file_path
-                            break
-                if object_file_path: break
-
-            if object_file_path:
-                object_content = ResourceLoader.load_file_content(object_file_path)
-                if object_content:
-                    user_prompt += f"Object Info:\n{object_content}\n"
-                else:
-                    logging.warning(f"Object file could not be loaded: {object_file_path}")
+                        basename = os.path.basename(file_path)
+                        object_file_basenames.append(basename)
+
+                        # Try to match the current topic's object (only if no match has been found yet)
+                        if object_name_from_topic and not matched_object_file_path:
+                            cleaned_resource_name = basename
+                            if cleaned_resource_name.startswith("景点信息-"):
+                                cleaned_resource_name = cleaned_resource_name[len("景点信息-"):]
+                            if cleaned_resource_name.endswith(".txt"):
+                                cleaned_resource_name = cleaned_resource_name[:-len(".txt")]
+
+                            if cleaned_resource_name and cleaned_resource_name in object_name_from_topic:
+                                matched_object_file_path = file_path
+                                matched_object_basename = basename
+                                # Note: no break here, keep collecting all file names
+
+            # Build the prompt - Part 1: the list of files
+            if object_file_basenames:
+                user_prompt += "Object信息:\n"
+                # user_prompt += f"{object_file_basenames}\n\n" # Dumping the raw list may not be clear enough
+                for fname in object_file_basenames:
+                    user_prompt += f"- {fname}\n"
+                user_prompt += "\n" # Add a blank line
+                logging.info(f"Listed {len(object_file_basenames)} available object files.")
             else:
-                # If basename match fails, maybe topic_item['object'] is just 'XXX'?
-                # Try finding based on substring? This might be ambiguous.
-                logging.warning(f"Object file path not found in config matching object: {topic_item.get('object')}")
-        except KeyError:
-            logging.warning("Warning: 'object' key missing in topic_item for Object prompt.")
+                logging.warning("No resource directory entry found with type 'Object', or it has no file paths.")
+
+            # Build the prompt - Part 2: inject the matched file's content
+            if matched_object_file_path:
+                logging.info(f"Attempting to load content for matched object file: {matched_object_basename}")
+                matched_object_content = ResourceLoader.load_file_content(matched_object_file_path)
+                if matched_object_content:
+                    user_prompt += f"{matched_object_basename}:\n{matched_object_content}\n\n"
+                    logging.info(f"Successfully loaded and injected content for: {matched_object_basename}")
+                else:
+                    logging.warning(f"Object file matched ({matched_object_basename}) but could not be loaded or is empty.")
+            elif object_name_from_topic: # Only warn when the topic names an object but no matching file was found
+                logging.warning(f"Could not find a matching Object resource file to inject content for '{object_name_from_topic}'. Only the list of files was provided.")
+
+        except KeyError:
+            logging.warning("Warning: 'object' key potentially missing in topic_item.")
         except Exception as e:
-            logging.exception("Error processing Object prompt:")
+            logging.exception("Error processing Object prompt section:")

         # 3. Add Product info (if applicable)
         try:

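Read on its own, the new rule is: list every Object file's basename in the prompt, and inject file content only for the first file whose cleaned name (the 景点信息- prefix and .txt suffix stripped) is contained in the topic's object string. A standalone sketch of that matching rule, using a hypothetical helper name and placeholder paths (no such helper exists in prompt_manager.py):

import os


def match_object_file(object_name_from_topic, resource_dir_config):
    """Return (all_basenames, matched_path) using the cleaned-name containment rule."""
    basenames, matched_path = [], None
    for dir_info in resource_dir_config:
        if dir_info.get("type") != "Object":
            continue
        for file_path in dir_info.get("file_path", []):
            basename = os.path.basename(file_path)
            basenames.append(basename)
            if object_name_from_topic and matched_path is None:
                cleaned = basename
                if cleaned.startswith("景点信息-"):
                    cleaned = cleaned[len("景点信息-"):]
                if cleaned.endswith(".txt"):
                    cleaned = cleaned[:-len(".txt")]
                if cleaned and cleaned in object_name_from_topic:
                    matched_path = file_path
    return basenames, matched_path


# Example with placeholder paths: the second file matches because its cleaned
# name "尚书第建筑群" is contained in the topic's object string.
names, path = match_object_file(
    "尚书第建筑群",
    [{"type": "Object", "file_path": [
        "resource/Object/景点信息-XXX.txt",
        "resource/Object/景点信息-尚书第建筑群.txt",
    ]}],
)
print(names)  # ['景点信息-XXX.txt', '景点信息-尚书第建筑群.txt']
print(path)   # resource/Object/景点信息-尚书第建筑群.txt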