361 lines
15 KiB
Python
361 lines
15 KiB
Python
from vllm import LLM, SamplingParams
|
||
from transformers import AutoTokenizer
|
||
from vllm.assets.image import ImageAsset
|
||
from vllm.assets.video import VideoAsset
|
||
from dataclasses import dataclass, field
|
||
from typing import Optional, ClassVar, List, Tuple
|
||
import numpy as np
|
||
from numpy.typing import NDArray
|
||
from PIL import Image
|
||
import librosa
|
||
import librosa.util
|
||
import cv2
|
||
import os
|
||
from openai import OpenAI
|
||
import base64
|
||
# import ray  # Ray import disabled

# # Initialize Ray in local mode to avoid distributed-communication issues
# ray.init(local_mode=True, ignore_reinit_error=True)

# Force Hugging Face libraries into offline mode so nothing below
# attempts a network fetch at import or run time.
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_HUB_OFFLINE"] = "1"
|
||
|
||
@dataclass(frozen=True)
class VideoAsset:
    """Immutable descriptor for a named, downloadable video asset.

    NOTE(review): this class shadows the ``VideoAsset`` imported from
    ``vllm.assets.video`` at the top of the file — confirm that is
    intentional.
    """

    # Asset key used to look up the backing media file.
    name: str
    # Maximum number of frames to extract; -1 means "all frames".
    num_frames: int = -1

    # Registry mapping asset names to their media filenames.
    _NAME_TO_FILE: ClassVar[dict[str, str]] = {
        "baby_reading": "sample_demo_1.mp4",
    }

    @property
    def filename(self) -> str:
        """Media filename registered for this asset's name."""
        return self._NAME_TO_FILE[self.name]

    @property
    def pil_images(self) -> list[Image.Image]:
        """Decode the video into a list of PIL frames."""
        return video_to_pil_images_list(
            download_video_asset(self.filename), self.num_frames
        )

    @property
    def np_ndarrays(self) -> NDArray:
        """Decode the video into a stacked frame array."""
        return video_to_ndarrays(
            download_video_asset(self.filename), self.num_frames
        )

    def get_audio(self, sampling_rate: Optional[float] = None) -> NDArray:
        """Load the asset's audio track as a waveform array."""
        path = download_video_asset(self.filename)
        waveform, _ = librosa.load(path, sr=sampling_rate)
        return waveform
|
||
|
||
@dataclass(frozen=True)
class LocalVideoAsset:
    """Immutable descriptor for a video file that already exists on disk.

    Mirrors ``VideoAsset``'s interface so the two can be used
    interchangeably by callers.
    """

    # Filesystem path to the video file.
    local_path: str
    name: str = "local_video"
    # Maximum number of frames to extract; -1 means "all frames".
    num_frames: int = -1

    @property
    def filename(self) -> str:
        """Alias for the local path (keeps parity with VideoAsset)."""
        return self.local_path

    @property
    def pil_images(self) -> list[Image.Image]:
        """Decode the local video into a list of PIL frames."""
        return video_to_pil_images_list(self.filename, self.num_frames)

    @property
    def np_ndarrays(self) -> NDArray:
        """Decode the local video into a stacked frame array."""
        return video_to_ndarrays(self.filename, self.num_frames)

    def get_audio(self, sampling_rate: Optional[float] = None) -> NDArray:
        """Load the file's audio track; return a 1-sample zero array on failure."""
        try:
            if not os.path.exists(self.filename):
                print(f"音频文件不存在: {self.filename}")
                return np.zeros(1)
            waveform, _ = librosa.load(self.filename, sr=sampling_rate)
            return waveform
        except Exception as e:
            # Best-effort: report the problem and hand back a placeholder
            # instead of crashing the caller.
            print(f"加载音频时出错: {e}")
            return np.zeros(1)
|
||
|
||
# 辅助函数实现
|
||
def download_video_asset(filename: str) -> str:
    """Resolve *filename* to a local filesystem path.

    Absolute paths and explicit relative paths (``./...``) are returned
    unchanged; any other value is treated as a downloadable asset name.

    Args:
        filename: Path or bare asset filename.

    Returns:
        A local path to the (possibly downloaded) video.

    BUG FIX: the previous version returned an f-string with no
    placeholder (``f"/path/to/downloaded/(unknown)"``), silently
    discarding *filename*; the asset name is now interpolated.
    """
    if filename.startswith(("/", "./")):
        return filename
    # Placeholder for the real download logic.
    return f"/path/to/downloaded/{filename}"
|
||
|
||
def video_to_pil_images_list(video_path: str, num_frames: int) -> list[Image.Image]:
    """Convert a video into a list of RGB PIL images.

    Args:
        video_path: Path to the video file.
        num_frames: Number of frames to sample uniformly across the video;
            any value <= 0 (or >= the total frame count) means "all frames".

    Returns:
        List of PIL images; empty when the file is missing or unreadable.
    """
    if not os.path.exists(video_path):
        print(f"视频文件不存在: {video_path}")
        return []

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"无法打开视频: {video_path}")
        return []

    pil_images: list[Image.Image] = []
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        duration = total_frames / fps if fps > 0 else 0

        print(f"视频信息: 总帧数={total_frames}, FPS={fps:.2f}, 时长={duration:.2f}秒")

        # Pick the sampling interval; fall back to "every frame" when no
        # (or an impossible) target count was requested.
        if num_frames > 0 and num_frames < total_frames:
            frame_interval = total_frames / num_frames
            print(f"将提取 {num_frames} 帧,采样间隔为每 {frame_interval:.2f} 帧")
        else:
            frame_interval = 1
            num_frames = total_frames
            print(f"将提取所有 {total_frames} 帧")

        frame_count = 0
        last_progress = -1

        while len(pil_images) < num_frames:
            success, frame = cap.read()
            if not success:
                break

            # BUG FIX: the old test `frame_count % int(frame_interval) == 0`
            # truncated fractional intervals, clustering all samples at the
            # start of the video (e.g. 40 of 100 frames -> frames 0..78 only).
            # Sampling when frame_count reaches the next fractional target
            # index spreads the frames uniformly over the whole video.
            if frame_count >= len(pil_images) * frame_interval:
                # OpenCV decodes to BGR; convert to RGB for PIL.
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil_images.append(Image.fromarray(rgb_frame))

                # Report progress roughly every 10%.
                progress = int(len(pil_images) / num_frames * 10)
                if progress > last_progress:
                    print(f"提取进度: {len(pil_images)}/{num_frames} ({len(pil_images)/num_frames*100:.1f}%)")
                    last_progress = progress

            frame_count += 1
    finally:
        # Always release the capture handle, even if decoding raises.
        cap.release()

    print(f"从视频中共提取了 {len(pil_images)} 帧")
    return pil_images
|
||
|
||
def video_to_ndarrays(video_path: str, num_frames: int) -> NDArray:
    """Convert a video into a stacked NumPy array of resized frames.

    Args:
        video_path: Path to the video file.
        num_frames: Frame budget forwarded to ``video_to_pil_images_list``.

    Returns:
        Array shaped ``[num_frames, 224, 224, 3]``; a single zero frame
        when no frames could be extracted.
    """
    frames = video_to_pil_images_list(video_path, num_frames)
    if not frames:
        print(f"未能从视频中提取帧: {video_path}")
        return np.zeros((1, 224, 224, 3))

    # Normalize every frame to 224x224 and stack along a new leading
    # axis -> [num_frames, height, width, channels].
    stacked_array = np.stack(
        [np.array(img.resize((224, 224))) for img in frames], axis=0
    )
    print(f"NumPy数组形状: {stacked_array.shape}")
    return stacked_array
|
||
|
||
def encode_image(image_path):
    """Read the file at *image_path* and return its base64 text encoding."""
    with open(image_path, "rb") as fh:
        raw = fh.read()
    return base64.b64encode(raw).decode("utf-8")
|
||
|
||
def encode_audio(audio_path):
    """Read the file at *audio_path* and return its base64 text encoding."""
    with open(audio_path, "rb") as fh:
        payload = fh.read()
    return base64.b64encode(payload).decode("utf-8")
|
||
|
||
# 使用示例
|
||
# Usage example
if __name__ == "__main__":
    # --- Local-asset experiments, kept for reference ----------------------
    # local_video = LocalVideoAsset(
    #     local_path="/root/autodl-tmp/hot_video_analyse/source/sample_demo_1.mp4",
    #     num_frames=-1,
    # )
    # pil_images = local_video.pil_images
    # np_arrays = local_video.np_ndarrays
    # audio = local_video.get_audio(sampling_rate=16000)

    # Base64-encode the extracted audio track for the request payload.
    base64_audio = encode_audio(
        "/root/autodl-tmp/hot_video_analyse/source/transcription/sample_demo_1_audio.wav"
    )

    client = OpenAI(
        # Replace with a real API key when not targeting the local vLLM server.
        api_key="EMPTY",
        base_url="http://localhost:8000/v1",
    )

    # Scene-change keyframes extracted from the video, in temporal order.
    # (Previously this was 18 copy-pasted encode_image calls.)
    _FRAME_DIR = "/root/autodl-tmp/hot_video_analyse/source/scene_change/scene_change_frames"
    _FRAME_FILES = [
        "scene_change_000034_t1.13s_ssim0.5673.jpg",
        "scene_change_000117_t3.90s_ssim0.1989.jpg",
        "scene_change_000119_t3.97s_ssim0.2138.jpg",
        "scene_change_000140_t4.67s_ssim0.5160.jpg",
        "scene_change_000160_t5.33s_ssim0.4934.jpg",
        "scene_change_000180_t6.00s_ssim0.3577.jpg",
        "scene_change_000201_t6.70s_ssim0.3738.jpg",
        "scene_change_000222_t7.40s_ssim0.6104.jpg",
        "scene_change_000243_t8.10s_ssim0.5099.jpg",
        "scene_change_000261_t8.70s_ssim0.4735.jpg",
        "scene_change_000281_t9.37s_ssim0.2703.jpg",
        "scene_change_000301_t10.03s_ssim0.2772.jpg",
        "scene_change_000321_t10.70s_ssim0.3721.jpg",
        "scene_change_000341_t11.37s_ssim0.3700.jpg",
        "scene_change_000361_t12.03s_ssim0.3494.jpg",
        "scene_change_000382_t12.73s_ssim0.3423.jpg",
        "scene_change_000402_t13.40s_ssim0.3252.jpg",
        "scene_change_000421_t14.03s_ssim0.2086.jpg",
    ]
    base64_images = [encode_image(f"{_FRAME_DIR}/{name}") for name in _FRAME_FILES]

    # NOTE(review): the original code encoded frame 0 but never included it
    # in the request; that behavior is preserved here by slicing [1:] —
    # confirm whether frame 0 was meant to be sent.
    image_parts = [
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{img}"},
        }
        for img in base64_images[1:]
    ]

    completion = client.chat.completions.create(
        model="/root/autodl-tmp/llm",  # model ID registered with the vLLM server
        messages=[
            {
                "role": "user",
                "content": image_parts
                + [
                    # Audio input
                    {
                        "type": "input_audio",
                        "input_audio": {
                            "data": base64_audio,
                            "format": "wav",
                        },
                    },
                    # Text prompt
                    {
                        "type": "text",
                        "text": "描述一下关于这个视频的浅层特征",
                    },
                ],
            }
        ],
        # Output modalities; the server supports ["text","audio"] or ["text"].
        modalities=["text"],
        # audio={"voice": "Cherry", "format": "wav"},
        # Streaming is required by this endpoint.
        stream=True,
        stream_options={"include_usage": True},
    )
    # Alternative prompts tried earlier:
    # "形容一下这段音频。这些图像是一个视频的连续帧,请结合音频描述这个视频的具体过程和内容。并解释一下这个视频吸引人眼球的原因"
    # "识别这段音频,并描述一下这段音频的内容。"

    # Accumulate the streamed response; the final usage-only chunk has no
    # choices and carries token accounting instead.
    full_response = ""
    usage_info = None

    print("正在生成回复...")
    for chunk in completion:
        if chunk.choices:
            delta = chunk.choices[0].delta
            if delta.content:
                full_response += delta.content
                # Echo tokens as they arrive.
                print(delta.content, end='', flush=True)
        else:
            usage_info = chunk.usage

    # Print the assembled reply.
    print("\n" + "=" * 50)
    print("完整回复:")
    print("=" * 50)
    print(full_response)

    # Print token-usage accounting when the server provided it.
    if usage_info:
        print("\n" + "=" * 50)
        print("使用情况:")
        print("=" * 50)
        print(usage_info)
|
||
|
||
|
||
|