129 lines
4.8 KiB
Python
129 lines
4.8 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
import os
|
||
import sys
|
||
import json
|
||
import time
|
||
import logging
|
||
|
||
# Determine the project root directory (this script is assumed to live one
# directory below it, e.g. in examples/) so the project's `core` package can be
# imported when the script is run directly.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if PROJECT_ROOT not in sys.path:
    # Prepend rather than append so this project's `core` package wins over any
    # same-named package that appears earlier on sys.path (e.g. site-packages).
    sys.path.insert(0, PROJECT_ROOT)

# Now import from core
try:
    from core.ai_agent import AI_Agent
except ImportError as e:
    # Logging is not configured yet at this point; the module-level
    # logging.critical() calls basicConfig() itself when the root logger has no
    # handlers, so this message still reaches stderr before we exit.
    logging.critical(f"Failed to import AI_Agent. Ensure '{PROJECT_ROOT}' is in sys.path and core/ai_agent.py exists. Error: {e}")
    sys.exit(1)
|
||
|
||
def load_config(config_path):
    """Load configuration from a JSON file.

    The file is decoded as UTF-8; a leading UTF-8 byte-order mark (as written
    by some Windows editors) is accepted and stripped.

    Args:
        config_path: Path of the JSON configuration file.

    Returns:
        The decoded JSON value (this script uses it as a dict), or None when
        the file does not exist, is not valid JSON, or cannot be read for any
        other reason. Failures are logged rather than raised.
    """
    try:
        # 'utf-8-sig' reads plain UTF-8 unchanged and drops a BOM if present;
        # with plain 'utf-8' json.load rejects BOM-prefixed files outright.
        with open(config_path, 'r', encoding='utf-8-sig') as f:
            config = json.load(f)
    except FileNotFoundError:
        logging.error("Error: Configuration file not found at %s", config_path)
        return None
    except json.JSONDecodeError as e:
        # Report where parsing failed so a malformed file is easy to fix.
        logging.error(
            "Error: Could not decode JSON from %s (line %d, column %d: %s)",
            config_path, e.lineno, e.colno, e.msg,
        )
        return None
    except Exception:
        # Best-effort loader: any other failure (permissions, a directory
        # path, undecodable bytes, ...) is logged with its traceback and
        # reported to the caller as "no config".
        logging.exception("An unexpected error occurred loading config %s:", config_path)
        return None

    logging.info("Config loaded successfully from %s", config_path)
    return config
|
||
|
||
def main():
    """Run a manual streaming smoke test against the configured AI_Agent.

    Loads ``poster_gen_config.json`` from PROJECT_ROOT and builds an AI_Agent
    from its ``api_url``/``model``/``api_key``. The optional keys
    ``request_timeout`` (default 30), ``max_retries`` (default 3) and
    ``stream_chunk_timeout`` (default 60) are also used. The function sends a
    fixed short-story prompt through ``work_stream`` and echoes each streamed
    chunk to stdout as it arrives. Elapsed time and the number of characters
    received are logged when the stream ends.

    Exits with status 1 if the configuration cannot be loaded, lacks the
    required connection settings, or if any error occurs while creating the
    agent or consuming the stream. An agent that was created is always closed.
    """
    # --- Basic Logging Setup ---
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )
    # --- End Logging Setup ---

    logging.info("Starting AI Agent Stream Test...")

    # The config file is resolved against the project root, not the CWD.
    config_path = os.path.join(PROJECT_ROOT, "poster_gen_config.json")
    config = load_config(config_path)
    if config is None:
        logging.critical("Failed to load configuration. Exiting test.")
        sys.exit(1)

    # Example prompts (Chinese). The system prompt casts the model as a helpful
    # assistant that is good at short stories. The user prompt asks for a short
    # story about a travelling robot that discovers new life forms on an
    # exotic planet.
    system_prompt = "你是一个乐于助人的AI助手,擅长写短篇故事。"
    user_prompt = "请写一个关于旅行机器人的短篇故事,它在一个充满异国情调的星球上发现了新的生命形式。"

    ai_agent = None
    failed = False
    try:
        # --- Extract AI Agent parameters from config ---
        ai_api_url = config.get("api_url")
        ai_model = config.get("model")
        ai_api_key = config.get("api_key")
        request_timeout = config.get("request_timeout", 30)
        max_retries = config.get("max_retries", 3)
        stream_chunk_timeout = config.get("stream_chunk_timeout", 60)

        # The connection settings have no sensible defaults; abort early.
        # (SystemExit is not an Exception, so it is not swallowed below.)
        if not all([ai_api_url, ai_model, ai_api_key]):
            logging.critical("Missing required AI configuration (api_url, model, api_key) in config. Exiting test.")
            sys.exit(1)

        logging.info("Initializing AI Agent for stream test...")
        ai_agent = AI_Agent(
            api_url=ai_api_url,
            model=ai_model,
            api_key=ai_api_key,
            timeout=request_timeout,
            max_retries=max_retries,
            stream_chunk_timeout=stream_chunk_timeout,
        )

        logging.info("Calling ai_agent.work_stream...")
        # Sampling parameters, read from the content_* keys of the config.
        temperature = config.get("content_temperature", 0.7)
        top_p = config.get("content_top_p", 0.9)
        presence_penalty = config.get("content_presence_penalty", 0.0)

        # perf_counter is monotonic, so the measured duration is unaffected by
        # wall-clock adjustments during the run.
        start_time = time.perf_counter()
        stream_generator = ai_agent.work_stream(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            info_directory=None,  # No extra context folder for this test
            temperature=temperature,
            top_p=top_p,
            presence_penalty=presence_penalty,
        )

        logging.info("Processing stream response:")
        total_chars = 0
        for chunk in stream_generator:
            # Echo each chunk immediately so the text appears as it streams in.
            print(chunk, end="", flush=True)
            total_chars += len(chunk)
        # End the streamed line on stdout. With the basicConfig setup above, log
        # records go to stderr, so a newline inside a log message would not
        # terminate it.
        print()
        elapsed = time.perf_counter() - start_time

        logging.info("--- Stream Finished ---")
        logging.info("Total time: %.2f seconds", elapsed)
        logging.info("Total characters received: %d", total_chars)

    except Exception:
        # Config values are read with dict.get, so a KeyError here would come
        # from the agent rather than the config file. Every failure is
        # reported with its full traceback.
        failed = True
        logging.exception("An error occurred during the stream test:")
    finally:
        # Ensure the agent is closed
        if ai_agent is not None:
            logging.info("Closing AI Agent...")
            ai_agent.close()
            logging.info("AI Agent closed.")

    # Signal failure to the shell/CI, consistent with the config-error exits.
    if failed:
        sys.exit(1)
|
||
if __name__ == "__main__":
    # Only run the stream test when executed as a script, never on import.
    main()