Nova-fine-tuning / src /chatbot.py
StevesInfinityDrive's picture
Update src/chatbot.py
b06f43c verified
raw
history blame contribute delete
860 Bytes
import torch
from model_loader import load_model
# Load the model, its tokenizer and the target device once, at import time.
# generate_response() reads these three module-level names on every call.
(model, tokenizer, device) = load_model()
def generate_response(prompt, *, max_new_tokens=150, temperature=0.7,
                      top_p=0.9, top_k=50, repetition_penalty=1.2):
    """Generate a sampled reply to *prompt* with the module-level model.

    Args:
        prompt: User text the model is conditioned on.
        max_new_tokens: Maximum number of tokens generated *after* the
            prompt. Unlike ``max_length``, the prompt's own length does not
            consume this budget, so long prompts still get a reply.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling probability mass forwarded to ``generate``.
        top_k: Top-k sampling cutoff forwarded to ``generate``.
        repetition_penalty: Repetition penalty forwarded to ``generate``.

    Returns:
        The decoded reply text with special tokens removed. For
        decoder-only models the echoed prompt tokens are dropped, so only
        the newly generated text is returned.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=True,  # sample rather than decode greedily, for varied replies
        )
    generated = outputs[0]
    # Decoder-only models return prompt + continuation in one sequence; slice
    # the prompt off so the bot does not repeat the user's words back.
    # Encoder-decoder models already return only the generated sequence.
    config = getattr(model, "config", None)
    if not getattr(config, "is_encoder_decoder", False):
        generated = generated[inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(generated, skip_special_tokens=True)
# Interactive console loop for manually testing the chatbot.
# Type "exit" or "quit" (or press Ctrl-D / Ctrl-C) to leave.
if __name__ == "__main__":
    while True:
        try:
            user_input = input("You: ")
        except (EOFError, KeyboardInterrupt):
            # End the session cleanly instead of dumping a traceback.
            print()
            break
        command = user_input.strip()
        if command.lower() in ("exit", "quit"):
            break
        if not command:
            # A blank line may tokenize to zero input tokens, which the
            # model cannot condition on; just prompt again.
            continue
        bot_response = generate_response(user_input)
        print(f"Bot: {bot_response}")