"""Classify sampled product rows from SQLite with a vLLM-served Qwen3 model.

Rows are sampled from a local SQLite database, each row is turned into a chat
prompt, sent to vLLM's OpenAI-compatible API, and the parsed classification is
written to an Excel file together with the original row.
"""
import json
import os
import random
import re
import sqlite3
import sys
import time

import pandas as pd
from openai import OpenAI

from sql_prompt import create_prompt_json, read_json_file

# When vLLM downloads a model automatically, fetch it from ModelScope instead
# of HuggingFace.
# NOTE(review): this script only talks to an already-running server over HTTP,
# so the variable presumably has no effect unless vLLM is launched from this
# process or one of its children -- confirm it is still needed.
os.environ['VLLM_USE_MODELSCOPE'] = 'True'
|
# --- Run configuration -------------------------------------------------------
SORT_JSON_PATH = "sort.json"  # category definitions, loaded with read_json_file
TABLE_NAME = "data"  # table to sample; the database file is f"{TABLE_NAME}.db"
SORT_NAME = "产品类型"  # column used to filter the rows
SORT_VALUES = ("住宿", "门票", "抢购")  # accepted values of SORT_NAME
NUM_LIMIT = 500  # maximum number of rows sampled at random

# vLLM's OpenAI-compatible server. The key is a placeholder (presumably the
# server runs without an API key -- confirm).
OPENAI_API_KEY = "EMPTY"
OPENAI_API_BASE = "http://localhost:8000/v1"
MODEL_NAME = "Qwen3-8B"
# Derived so the name stays in sync with the limit/model: "500_Qwen3_8B_sort.xlsx".
EXCEL_FILE = f"{NUM_LIMIT}_{MODEL_NAME.replace('-', '_')}_sort.xlsx"

# Fields extracted from each model reply, in output-column order.
OUTPUT_FIELDS = (
    "primary_category",
    "secondary_category",
    "tertiary_category",
    "confidence",
    "reasoning",
)

# Fallback patterns for replies that are not valid JSON. Each accepts an ASCII
# or full-width colon (\uff1a) after the field name, with or without quotes.
FIELD_PATTERNS = {
    "primary_category": re.compile(
        r'primary_category["\s]*[:\uff1a]\s*"*([^",\n}]+)', re.IGNORECASE),
    "secondary_category": re.compile(
        r'secondary_category["\s]*[:\uff1a]\s*"*([^",\n}]+)', re.IGNORECASE),
    "tertiary_category": re.compile(
        r'tertiary_category["\s]*[:\uff1a]\s*"*([^",\n}]+)', re.IGNORECASE),
    "confidence": re.compile(
        r'confidence["\s]*[:\uff1a]\s*([0-9.]+)', re.IGNORECASE),
    "reasoning": re.compile(
        r'reasoning["\s]*[:\uff1a]\s*"*([^"}\n]+(?:\n[^"}\n]+)*)', re.IGNORECASE),
}


def quote_identifier(name: str) -> str:
    """Quote *name* as an SQLite identifier, doubling any embedded double quotes."""
    return '"' + name.replace('"', '""') + '"'


def fetch_sample_rows(db_path: str, table: str, column: str, values, limit: int):
    """Randomly sample up to *limit* rows of *table* whose *column* is in *values*.

    The values and the limit are bound as SQL parameters, so any number of
    values (including exactly one) and values containing quotes work. The
    connection is always closed.

    Returns:
        ``(column_names, rows)``: the column names from the cursor description
        and the fetched row tuples.

    Raises:
        ValueError: if *column* is not a column of *table*.
        sqlite3.Error: on database errors (e.g. the table does not exist).
    """
    placeholders = ", ".join(["?"] * len(values))
    query = (
        f"SELECT * FROM {quote_identifier(table)} "
        f"WHERE {quote_identifier(column)} IN ({placeholders}) "
        "ORDER BY RANDOM() LIMIT ?"
    )
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.execute(query, (*values, limit))
        columns = [desc[0] for desc in cursor.description]
        # SQLite falls back to treating an unknown double-quoted identifier as
        # a string literal, which would silently match nothing -- fail loudly.
        if column.casefold() not in {name.casefold() for name in columns}:
            raise ValueError(f"column {column!r} not found in table {table!r}")
        rows = cursor.fetchall()
    finally:
        conn.close()
    return columns, rows


def fix_mojibake(text: str) -> str:
    """Undo UTF-8 text that was mis-decoded as Latin-1, only if that round-trips.

    Text that cannot be encoded as Latin-1 (e.g. Chinese) or whose bytes are not
    valid UTF-8 (e.g. a genuine "Café") is returned unchanged -- never lossily.
    """
    try:
        return text.encode("latin-1").decode("utf-8")
    except UnicodeError:
        return text


def strip_think(text: str) -> str:
    """Return the stripped answer after the last ``</think>`` tag.

    If there is no ``</think>`` tag, *text* is returned unchanged.
    """
    _, sep, answer = text.rpartition("</think>")
    return answer.strip() if sep else text


def load_json_object(text: str) -> dict:
    """Parse the outermost ``{...}`` span of *text* as a JSON object.

    Returns ``{}`` when there is no such span or it is not a valid JSON object.
    """
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end <= start:
        return {}
    try:
        data = json.loads(text[start:end + 1])
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return {}
    return data if isinstance(data, dict) else {}


def parse_classification(text: str) -> list:
    """Extract the OUTPUT_FIELDS values from a raw model reply.

    Anything up to ``</think>`` is discarded. A JSON object in the answer is
    preferred, because it keeps commas and escaped quotes inside values; any
    field it lacks is looked up with FIELD_PATTERNS instead. Returns one string
    per field, in OUTPUT_FIELDS order, with "" for fields that were not found.
    """
    answer = strip_think(fix_mojibake(text))
    data = load_json_object(answer)
    values = []
    for field in OUTPUT_FIELDS:
        value = data.get(field)
        if value is not None:
            if isinstance(value, str):
                extracted = value.strip()
            else:
                extracted = json.dumps(value, ensure_ascii=False)
        else:
            match = FIELD_PATTERNS[field].search(answer)
            extracted = match.group(1).strip() if match else ""
        values.append(fix_mojibake(extracted))
    return values


def main():
    """Sample rows, classify each one with the model, and write the Excel report."""
    # Make Chinese output printable regardless of the console's default encoding.
    sys.stdout.reconfigure(encoding='utf-8')

    category_data = read_json_file(SORT_JSON_PATH)
    db_columns, rows = fetch_sample_rows(
        f"{TABLE_NAME}.db", TABLE_NAME, SORT_NAME, SORT_VALUES, NUM_LIMIT
    )

    # Point the OpenAI client at the vLLM server's OpenAI-compatible API.
    client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_BASE)

    results = []
    for i, row in enumerate(rows):
        loop_start_time = time.perf_counter()

        prompt = create_prompt_json(row, category_data)
        # Serialize to a JSON string, keeping Chinese characters unescaped.
        prompt_str = json.dumps(prompt, ensure_ascii=False)
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt_str}],
        )

        for choice in completion.choices:
            # message.content is Optional in the openai SDK; treat None as "".
            generated_text = choice.message.content or ""
            results.append([*row, *parse_classification(generated_text)])

        loop_duration = time.perf_counter() - loop_start_time
        print(f"第 {i+1} 次模型推理时间: {loop_duration:.2f} 秒")

    # Database columns first, then the parsed model fields. Passing the columns
    # to the constructor also works when no rows were sampled.
    df = pd.DataFrame(results, columns=[*db_columns, *OUTPUT_FIELDS])

    # Move the third column to the front of the report.
    cols = df.columns.tolist()
    cols.insert(0, cols.pop(2))
    df = df[cols]

    df.to_excel(EXCEL_FILE, index=False)
    print(f"\n所有生成文本已保存到 {EXCEL_FILE}")


if __name__ == "__main__":
    main()