from typing import Dict, Any
import base64
import tempfile
import os
import sys
# Make sure the videollama2 module can be imported (the model code must live in the same directory or be installed)
sys.path.append('./')
from videollama2 import model_init, mm_infer
from videollama2.utils import disable_torch_init
class EndpointHandler:
    def __init__(self, path=""):
        # Disable torch's automatic parameter initialization to avoid redundant work during loading
        disable_torch_init()
        # Model path: if the path passed in by the HF environment is empty, fall back to the official repo ID
        self.model_path = path or "DAMO-NLP-SG/VideoLLaMA2-7B-16F"
        # Load the model, processor, and tokenizer
        self.model, self.processor, self.tokenizer = model_init(self.model_path)
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Expected input format:
        {
            "video": "<base64 string>",  # base64-encoded video file
            "prompt": "natural-language instruction describing the video"
        }
        or
        {
            "image": "<base64 string>",  # base64-encoded image file
            "prompt": "natural-language instruction describing the image"
        }
        """
        # Determine the input modality
        if "video" in data:
            modal = "video"
            file_b64 = data["video"]
        elif "image" in data:
            modal = "image"
            file_b64 = data["image"]
        else:
            return {"error": "Request must contain a 'video' or 'image' field"}
        prompt = data.get("prompt", "Describe the content.")
        # Write the decoded bytes to a temporary file for the processor to read
        suffix = ".mp4" if modal == "video" else ".png"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file.write(base64.b64decode(file_b64))
            tmp_path = tmp_file.name
        try:
            # Preprocess the input and run model inference
            inputs = self.processor[modal](tmp_path)
            output = mm_infer(
                inputs,
                prompt,
                model=self.model,
                tokenizer=self.tokenizer,
                do_sample=False,
                modal=modal
            )
        finally:
            # Clean up the temporary file
            os.remove(tmp_path)
        # Return a uniform structure so callers can parse the response easily
        return {
            "modal": modal,
            "prompt": prompt,
            "result": output
        }
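
# Local smoke test: a minimal sketch of how a caller could exercise this handler,
# assuming a sample clip named "sample.mp4" sits next to this script (the filename
# is illustrative). In production the Inference Endpoints runtime builds `data`
# from the HTTP request body instead of this block.
if __name__ == "__main__":
    with open("sample.mp4", "rb") as f:
        payload = {
            "video": base64.b64encode(f.read()).decode("utf-8"),
            "prompt": "Describe the content of this video.",
        }
    handler = EndpointHandler()
    print(handler(payload))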