Update handler.py
handler.py (CHANGED, +16 -17)
@@ -43,22 +43,21 @@ class EndpointHandler:
         )

     def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
-        """
-        data format:
-        {
-            "inputs": "your prompt here"
-        }
-        """
         prompt = data["inputs"]
-
-        #
-
-
-
+
+        # ① Automatically grab the GPU that holds the embedding layer
+        first_device = next(self.model.parameters()).device
+        inputs = self.tokenizer(prompt, return_tensors="pt").to(first_device)
+
+        # ② Generate (the rest of the logic is unchanged)
         with torch.inference_mode():
-            output_ids = self.model.generate(
-
-
-
-
-        return {
+            output_ids = self.model.generate(
+                **inputs,
+                **self.generation_kwargs,
+            )
+
+        return {
+            "generated_text": self.tokenizer.decode(
+                output_ids[0], skip_special_tokens=True
+            )
+        }
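
Why the device lookup: when the model is loaded sharded across GPUs (e.g. with accelerate's device_map="auto"), input tensors must land on the device of the first shard, which holds the embedding layer, or generate() raises a device-mismatch error; next(self.model.parameters()).device returns exactly that device. A minimal local smoke test of the updated handler, assuming the repo follows the standard Hugging Face Inference Endpoints custom-handler convention (an EndpointHandler(path) constructor and a dict payload with an "inputs" key) and that self.generation_kwargs is populated in __init__, as the diff implies:

    # Hypothetical smoke test; assumes handler.py sits at the repo root
    # next to the model weights, per the Inference Endpoints layout.
    from handler import EndpointHandler

    handler = EndpointHandler(path=".")  # loads model + tokenizer from this repo
    out = handler({"inputs": "your prompt here"})  # payload shape from the old docstring
    print(out["generated_text"])

Note that forwarding the whole tokenizer output via **inputs (rather than input_ids alone) also passes attention_mask, which generate() needs for correct padding behavior.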