Initial commit
This commit is contained in:
commit 8bfc191771
20  sort.json  Normal file
@@ -0,0 +1,20 @@
{
    "游玩类": {
        "景点": ["红色景区", "观光街区", "地标建筑", "自然景观", "人文古迹", "古镇/古村"],
        "动植物园": ["动物园", "水族馆", "植物园"],
        "旅游项目": ["水上体验", "城市/高空观光", "温泉", "滑雪场", "露营地"],
        "泛主题乐园": ["游乐园", "影视基地", "水上乐园"],
        "展馆/展览": ["博物馆", "展览馆", "科技馆", "美术馆", "艺术馆", "纪念馆"],
        "公园广场": ["公园广场"],
        "其他游玩": ["其他游玩"]
    },
    "住宿类": {
        "酒店宾馆": ["经济型酒店", "舒适性酒店", "高档型酒店", "豪华型酒店"],
        "客栈民宿": ["客栈民宿"],
        "其他住宿": ["其他住宿"]
    },
    "旅行社类": {
        "境外旅行社": ["境外旅行社"],
        "境内旅行社": ["境内旅行社"]
    }
}
90  sql_prompt.py  Normal file
@@ -0,0 +1,90 @@
import json
import os
import pandas as pd
import sqlite3

# Read a JSON file and return its contents as a dict
def read_json_file(file_path):
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        print(f"读取文件 {file_path} 时出错: {e}")
        return {}

# Main processing function: build the prompt structure for one record
def create_prompt_json(data, category_data):

    item = data
    category = category_data
    # Create a JSON structure containing the instruction, the record and the category system
    prompt_data = {
        "instruction": "将基本信息根据分类体系进行分类判断",
        "task_description": "你是一个精准的内容分类专家。请阅读提供的基本信息(item_info),并根据分类体系(category_system)为该项目确定最合适的分类。",
        "input": {
            "item_info": item,
            "category_system": category
        },
        "output_requirements": {
            "fields": {
                "primary_category": "一级分类名称",
                "secondary_category": "二级分类名称(如果有)",
                "tertiary_category": "三级分类名称(如果有)",
                "confidence": "分类置信度(0-1)",
                "reasoning": "简要说明分类理由"
            }
        }
    }

    return prompt_data

# Usage example
if __name__ == "__main__":
    sort_json_path = "sort.json"  # category system file
    category_data = read_json_file(sort_json_path)

    # Connect to the database
    table_name = 'data'
    conn = sqlite3.connect(f'{table_name}.db')  # replace with your database file name
    cursor = conn.cursor()
    sort_name = '产品类型'
    sort_value = ('住宿', '门票', '抢购')
    num_limit = "500"
    # SQL query: randomly select 500 records of the specified product types
    query = f"""
    SELECT * FROM data
    WHERE {sort_name} IN {sort_value}
    ORDER BY RANDOM()
    LIMIT {num_limit}
    """
    cursor.execute(query)
    result = cursor.fetchall()
    results = []
    for i, row in enumerate(result):
        prompt_data = create_prompt_json(row, category_data)
        # print(prompt_data)
        result_entry = []
        result_entry.extend(row)  # add each element of row to result_entry
        result_entry.append(i)
        if i == 5:  # only keep the first 5 rows for a quick test
            break

        results.append(result_entry)

    df = pd.DataFrame(results)
    excel_file = '111.xlsx'
    # Use the database table's column names as the DataFrame header
    columns = [description[0] for description in cursor.description]
    columns.append('index')  # add the index column
    df.columns = columns

    # Move the third column to the front
    cols = df.columns.tolist()
    third_col = cols.pop(2)  # the third column has index 2
    cols.insert(0, third_col)
    df = df[cols]

    df.to_excel(excel_file, index=False)
    print(f"\n所有生成文本已保存到 {excel_file}")
    print(results)
    # Save the results to a JSON file
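Not part of the commit: a minimal sketch of the same random-sample query written with SQLite placeholders for the values and the limit, assuming the data.db schema used above. The column name cannot be a placeholder in sqlite3, so it still has to come from trusted code; all names and values mirror the script above and are illustrative only.

import sqlite3

conn = sqlite3.connect('data.db')                    # same database file as above
cursor = conn.cursor()
sort_name = '产品类型'                                # column name, interpolated, must be trusted
sort_value = ('住宿', '门票', '抢购')
placeholders = ', '.join('?' for _ in sort_value)    # "?, ?, ?"
query = f"""
SELECT * FROM data
WHERE {sort_name} IN ({placeholders})
ORDER BY RANDOM()
LIMIT ?
"""
cursor.execute(query, (*sort_value, 500))            # values and limit passed as parameters
rows = cursor.fetchall()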
157  sql_vllm.py  Normal file
@@ -0,0 +1,157 @@
from openai import OpenAI
import os
import json
import random
import pandas as pd
import re
import sys
import sqlite3

import time

from sql_prompt import create_prompt_json
from sql_prompt import read_json_file

# When downloading the model automatically, use ModelScope; otherwise it is downloaded from HuggingFace
os.environ['VLLM_USE_MODELSCOPE'] = 'True'
if __name__ == "__main__":

    sys.stdout.reconfigure(encoding='utf-8')

    # Load the category system
    sort_json_path = "sort.json"  # category system file
    category_data = read_json_file(sort_json_path)

    # Connect to the database
    table_name = 'data'
    conn = sqlite3.connect(f'{table_name}.db')  # replace with your database file name
    cursor = conn.cursor()
    sort_name = '产品类型'
    sort_value = ('住宿', '门票', '抢购')
    num_limit = "500"
    # SQL query: randomly select 500 records of the specified product types
    query = f"""
    SELECT * FROM data
    WHERE {sort_name} IN {sort_value}
    ORDER BY RANDOM()
    LIMIT {num_limit}
    """
    cursor.execute(query)
    result = cursor.fetchall()

    # Change the OpenAI API key and base URL so the client talks to vLLM's OpenAI-compatible API server
    openai_api_key = "EMPTY"
    openai_api_base = "http://localhost:8000/v1"
    client = OpenAI(
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    # work_dir = './prompt'

    results = []

    for i, row in enumerate(result):
        # Record the start time of this iteration
        loop_start_time = time.time()

        prompt = create_prompt_json(row, category_data)
        prompt_str = json.dumps(prompt, ensure_ascii=False)  # convert the dict to a JSON string, keeping Chinese characters

        # Use the chat interface
        messages = [
            {"role": "user", "content": prompt_str}
        ]

        completion = client.chat.completions.create(
            model="Qwen3-8B",
            messages=messages
        )

        # Get the generated text
        for choice in completion.choices:
            generated_text = choice.message.content

        # Make sure the text is UTF-8 encoded
        if isinstance(generated_text, bytes):
            generated_text = generated_text.decode('utf-8')
        elif isinstance(generated_text, str):
            # If it is a string that may not be UTF-8, re-encode to bytes and decode
            try:
                generated_text = generated_text.encode('latin-1').decode('utf-8')
            except UnicodeError:
                # If the conversion fails, keep the text unchanged
                pass

        result_entry = []
        result_entry.extend(row)  # add each element of row to result_entry

        # Use the content after </think> as the final answer
        think_parts = generated_text.split("</think>")
        if len(think_parts) > 1:
            final_answer = think_parts[1].strip()
            # continue processing only the final-answer part
            processing_text = final_answer
        else:
            # no </think> marker, use the raw text
            processing_text = generated_text

        # Process the extracted text
        try:
            # Pull the fields out with regular expressions instead of strict JSON parsing
            patterns = {
                "primary_category": r'primary_category["\s]*[::]\s*["]*([^",\n}]+)',
                "secondary_category": r'secondary_category["\s]*[::]\s*["]*([^",\n}]+)',
                "tertiary_category": r'tertiary_category["\s]*[::]\s*["]*([^",\n}]+)',
                "confidence": r'confidence["\s]*[::]\s*([0-9.]+)',
                "reasoning": r'reasoning["\s]*[::]\s*["]*([^"}\n]+(?:\n[^"}\n]+)*)'
            }

            # Extract each field in order and append it to result_entry
            for field, pattern in patterns.items():
                match = re.search(pattern, processing_text, re.IGNORECASE)
                if match:
                    # Make sure the matched text is UTF-8 encoded
                    match_text = match.group(1).strip()
                    if isinstance(match_text, bytes):
                        match_text = match_text.decode('utf-8')
                    elif isinstance(match_text, str):
                        try:
                            match_text = match_text.encode('latin-1').decode('utf-8', errors='replace')
                        except UnicodeError:
                            pass
                    result_entry.append(match_text)
                else:
                    result_entry.append("")
        except Exception as e:
            print(f"处理样本 {i} 时出错: {e}")

        # Store the processed result
        results.append(result_entry)

        # Compute and print the duration of this iteration
        loop_end_time = time.time()
        loop_duration = loop_end_time - loop_start_time
        print(f"第 {i+1} 次模型推理时间: {loop_duration:.2f} 秒")

    df = pd.DataFrame(results)
    excel_file = '500_Qwen3_8B_sort.xlsx'
    # Use the database table's column names as the DataFrame header
    columns = [description[0] for description in cursor.description]
    # Append the model output fields as extra columns
    columns.append('primary_category')
    columns.append('secondary_category')
    columns.append('tertiary_category')
    columns.append('confidence')
    columns.append('reasoning')
    df.columns = columns

    # Move the third column to the front
    cols = df.columns.tolist()
    third_col = cols.pop(2)  # the third column has index 2
    cols.insert(0, third_col)
    df = df[cols]

    df.to_excel(excel_file, index=False)
    print(f"\n所有生成文本已保存到 {excel_file}")
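Not part of the commit: the extraction above goes straight to regular expressions even though the prompt asks the model for structured fields. One possible refinement, sketched here with an illustrative helper name and sample reply, is to try strict JSON parsing first and fall back to a simplified version of the per-field regexes.

import json
import re

FIELDS = ["primary_category", "secondary_category", "tertiary_category", "confidence", "reasoning"]

def extract_fields(text):
    # Try strict JSON first: take the outermost {...} span in the reply, if any
    start, end = text.find('{'), text.rfind('}')
    if start != -1 and end > start:
        try:
            obj = json.loads(text[start:end + 1])
            return [str(obj.get(f, "")) for f in FIELDS]
        except json.JSONDecodeError:
            pass
    # Fall back to a simplified version of the per-field regexes used above
    values = []
    for f in FIELDS:
        m = re.search(rf'{f}["\s]*[::]\s*["]*([^",\n}}]+)', text, re.IGNORECASE)
        values.append(m.group(1).strip() if m else "")
    return values

# Illustrative sample reply only
print(extract_fields('{"primary_category": "住宿类", "confidence": 0.9}'))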
124  vllm_model.py  Normal file
@@ -0,0 +1,124 @@
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import os
import json
import random
import pandas as pd
import re

# Read a prompt JSON file
def load_from_json(json_path):
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data

# Build the prompt text from the JSON file
def build_prompt_from_json(prompt_data):
    # Extract the pieces according to the JSON structure
    instruction = prompt_data.get("instruction", "")
    task_description = prompt_data.get("task_description", "")
    input_data = prompt_data.get("input", {})
    output_requirements = prompt_data.get("output_requirements", {})

    # Assemble the full prompt
    prompt = f"{instruction}\n\n{task_description}\n\n"
    prompt += f"基本信息:\n{json.dumps(input_data.get('item_info', {}), ensure_ascii=False, indent=2)}\n\n"
    prompt += f"分类体系:\n{json.dumps(input_data.get('category_system', {}), ensure_ascii=False, indent=2)}\n\n"
    prompt += f"输出要求:\n{json.dumps(output_requirements, ensure_ascii=False, indent=2)}"

    return prompt

def get_completion(prompts, model, tokenizer=None, temperature=0.6, top_p=0.95, top_k=20, min_p=0, max_tokens=4096, max_model_len=8192):
    stop_token_ids = [151645, 151643]
    # Create the sampling parameters: temperature controls diversity, top_p controls nucleus sampling, top_k limits the number of candidate tokens, and min_p filters candidates by a probability threshold to balance quality and diversity
    sampling_params = SamplingParams(temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, max_tokens=max_tokens, stop_token_ids=stop_token_ids)  # max_tokens limits the maximum generated output length
    # Initialize the vLLM inference engine (note: a new engine is built on every call)
    llm = LLM(model=model, tokenizer=tokenizer, max_model_len=max_model_len, trust_remote_code=True)  # max_model_len limits the total input plus output length the model can handle
    outputs = llm.generate(prompts, sampling_params)
    return outputs

# When downloading the model automatically, use ModelScope; otherwise it is downloaded from HuggingFace
os.environ['VLLM_USE_MODELSCOPE'] = 'True'
if __name__ == "__main__":
    random.seed(114514)
    random_numbers = [random.randint(0, 61929) for _ in range(20)]

    # Model path and tokenizer for the vLLM inference engine
    model = '/root/autodl-tmp/Qwen/Qwen3-8B'  # model path
    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)  # load the tokenizer

    work_dir = './prompt'
    data_dir = './data'

    results = []

    for i in random_numbers:

        prompt_json_path = f"{work_dir}/{i}_prompt.json"  # prompt file path
        prompt_data = load_from_json(prompt_json_path)
        prompt = build_prompt_from_json(prompt_data)

        data_json_path = f"{data_dir}/{i}.json"
        data_json = load_from_json(data_json_path)

        messages = [
            {"role": "user", "content": prompt}
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True  # enable thinking mode (default True)
        )

        outputs = get_completion(text, model, tokenizer=None, temperature=0.6, top_p=0.95, top_k=20, min_p=0)  # recommended settings for thinking mode: temperature=0.6, TopP=0.95, TopK=20, MinP=0

        for output in outputs:
            generated_text = output.outputs[0].text

            result_entry = {}

            # Add the sample ID field first
            result_entry["样本ID"] = i

            # Add the original data fields
            if isinstance(data_json, dict):
                for key, value in data_json.items():
                    result_entry[key] = value

            # Use the content after </think> as the final answer
            think_parts = generated_text.split("</think>")
            if len(think_parts) > 1:
                final_answer = think_parts[1].strip()
                # continue processing only the final-answer part
                processing_text = final_answer
            else:
                # no </think> marker, use the raw text
                processing_text = generated_text

            # Process the extracted text
            try:
                # Pull the fields out with regular expressions instead of strict JSON parsing
                patterns = {
                    "primary_category": r'primary_category["\s]*[::]\s*["]*([^",\n}]+)',
                    "secondary_category": r'secondary_category["\s]*[::]\s*["]*([^",\n}]+)',
                    "tertiary_category": r'tertiary_category["\s]*[::]\s*["]*([^",\n}]+)',
                    "confidence": r'confidence["\s]*[::]\s*([0-9.]+)',
                    "reasoning": r'reasoning["\s]*[::]\s*["]*([^"}\n]+(?:\n[^"}\n]+)*)'
                }

                for field, pattern in patterns.items():
                    match = re.search(pattern, processing_text, re.IGNORECASE)
                    if match:
                        result_entry[field] = match.group(1).strip()
            except Exception as e:
                print(f"处理样本 {i} 时出错: {e}")

            # Store the processed result
            results.append(result_entry)

    # Create a DataFrame and save it to Excel
    df = pd.DataFrame(results)
    excel_file = '20_random_Qwen3_8B_responses.xlsx'
    df.to_excel(excel_file, index=False)
    print(f"\n所有生成文本已保存到 {excel_file}")
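Not part of the commit: get_completion constructs a new LLM engine inside the function, so the loop above reloads the 8B model once per sampled prompt. A possible restructuring, assuming the same model path and sampling settings as above, builds the engine once and generates all prompts in a single batched call; the prompts list below is a placeholder for the 20 chat-templated strings built in the loop.

from vllm import LLM, SamplingParams

# Build the engine once and reuse it for every prompt (sketch, not the committed code)
llm = LLM(model='/root/autodl-tmp/Qwen/Qwen3-8B', max_model_len=8192, trust_remote_code=True)
sampling_params = SamplingParams(temperature=0.6, top_p=0.95, top_k=20, min_p=0,
                                 max_tokens=4096, stop_token_ids=[151645, 151643])

prompts = ["..."]  # placeholder: the chat-templated prompt strings from the loop above
outputs = llm.generate(prompts, sampling_params)  # one batched call instead of one engine per prompt
for output in outputs:
    generated_text = output.outputs[0].text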