from openai import OpenAI import os import json import random import pandas as pd import re import sys import sqlite3 import time from sql_prompt import create_prompt_json from sql_prompt import read_json_file # 自动下载模型时,指定使用modelscope; 否则,会从HuggingFace下载 os.environ['VLLM_USE_MODELSCOPE']='True' if __name__ == "__main__": sys.stdout.reconfigure(encoding='utf-8') # 连接到数据库 sort_json_path = "sort.json" # 分类信息文件 category_data = read_json_file(sort_json_path) # 文件路径 table_name = 'data' conn = sqlite3.connect(f'{table_name}.db') # 替换为您的数据库文件名 cursor = conn.cursor() sort_name = '产品类型' sort_value = ('住宿', '门票', '抢购') num_limit = "500" # SQL查询 - 随机选择500个指定产品类型的记录 query = f""" SELECT * FROM data WHERE {sort_name} IN {sort_value} ORDER BY RANDOM() LIMIT {num_limit} """ cursor.execute(query) result = cursor.fetchall() # 修改 OpenAI 的 API 密钥和 API 基础 URL 以使用 vLLM 的 API 服务器。 openai_api_key = "EMPTY" openai_api_base = "http://localhost:8000/v1" client = OpenAI( api_key=openai_api_key, base_url=openai_api_base, ) #work_dir = './prompt' results = [] for i, row in enumerate(result): # 记录开始时间 loop_start_time = time.time() prompt = create_prompt_json(row ,category_data) # 您的prompt文件路径 prompt_str = json.dumps(prompt, ensure_ascii=False) # 将字典转换为JSON字符串,保留中文字符 # 使用聊天接口 messages = [ {"role": "user", "content": prompt_str} ] completion = client.chat.completions.create( model="Qwen3-8B", messages=messages ) # 获取生成的文本 for choice in completion.choices: generated_text = choice.message.content # 确保文本使用UTF-8编码 if isinstance(generated_text, bytes): generated_text = generated_text.decode('utf-8') elif isinstance(generated_text, str): # 如果是字符串,但编码可能不是UTF-8,先转为bytes再解码 try: generated_text = generated_text.encode('latin-1').decode('utf-8') except UnicodeError: # 如果上面的转换失败,保持原样 pass result_entry = [] result_entry.extend(row) # 将row的所有元素分别添加到result_entry # 提取后的内容作为最终回答 think_parts = generated_text.split("") if len(think_parts) > 1: final_answer = think_parts[1].strip() # 继续处理的文本改为最终回答部分 processing_text = final_answer else: # 如果没有标记,使用原始文本 processing_text = generated_text # 处理提取出的文本内容 try: # JSON解析失败,使用正则表达式 patterns = { "primary_category": r'primary_category["\s]*[::]\s*["]*([^",\n}]+)', "secondary_category": r'secondary_category["\s]*[::]\s*["]*([^",\n}]+)', "tertiary_category": r'tertiary_category["\s]*[::]\s*["]*([^",\n}]+)', "confidence": r'confidence["\s]*[::]\s*([0-9.]+)', "reasoning": r'reasoning["\s]*[::]\s*["]*([^"}\n]+(?:\n[^"}\n]+)*)' } # 按顺序提取各个字段的值并追加到result_entry列表 for field, pattern in patterns.items(): match = re.search(pattern, processing_text, re.IGNORECASE) if match: # 确保匹配结果使用UTF-8编码 match_text = match.group(1).strip() if isinstance(match_text, bytes): match_text = match_text.decode('utf-8') elif isinstance(match_text, str): try: match_text = match_text.encode('latin-1').decode('utf-8', errors='replace') except UnicodeError: pass result_entry.append(match_text) else: result_entry.append("") except Exception as e: print(f"处理样本 {i} 时出错: {e}") # 添加处理后的结果 results.append(result_entry) # 计算并打印本次循环的执行时间 loop_end_time = time.time() loop_duration = loop_end_time - loop_start_time print(f"第 {i+1} 次模型推理时间: {loop_duration:.2f} 秒") df = pd.DataFrame(results) excel_file = '500_Qwen3_8B_sort.xlsx' # 获取数据库表的列名作为DataFrame的表头 columns = [description[0] for description in cursor.description] columns.append('primary_category') columns.append('secondary_category') columns.append('tertiary_category') columns.append('confidence') columns.append('reasoning') # 添加index列 df.columns = columns # 将第三列移到第一列 cols = df.columns.tolist() third_col = cols.pop(2) # 第三列索引为2 cols.insert(0, third_col) df = df[cols] df.to_excel(excel_file, index=False) print(f"\n所有生成文本已保存到 {excel_file}")