aYeShaSiddiqA committed
Commit 25d0bb1 · verified · 1 Parent(s): ce4d8c9

Update app.py

Files changed (1):
  1. app.py +30 -30
app.py CHANGED
@@ -1,36 +1,36 @@
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-
-model_id = "tiiuae/falcon-rw-1b" # lightweight model that works on CPU
-
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(model_id)
-
-generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
-
-def generate_story(character_1, character_2, theme, setting, scenario):
-    prompt = (
-        f"Characters: {character_1}, {character_2}\n"
-        f"Theme: {theme}\n"
-        f"Setting: {setting}\n"
-        f"Scenario: {scenario}\n"
-        f"Write a story suitable for children aged 6 to 12:"
-    )
-    output = generator(prompt, max_new_tokens=300, temperature=0.8)[0]["generated_text"]
-    return output[len(prompt):].strip()
-
-demo = gr.Interface(
-    fn=generate_story,
-    inputs=[
-        gr.Textbox(label="Character 1"),
-        gr.Textbox(label="Character 2"),
-        gr.Textbox(label="Theme"),
-        gr.Textbox(label="Setting"),
-        gr.Textbox(label="Scenario")
-    ],
-    outputs=gr.Textbox(label="📖 Generated Story"),
-    title="Genieverse Lite Story Generator",
-    description="Enter characters, theme, and scenario to create a children's story!"
-)
-
-demo.launch(server_name="0.0.0.0", server_port=7860)
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+import torch
+
+model_name = "ajibawa-2023/Young-Children-Storyteller-Mistral-7B"
+quant_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_compute_dtype=torch.float16,
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_type="nf4"
+)
+
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    quantization_config=quant_config,
+    device_map="auto",
+    torch_dtype=torch.float16,
+    trust_remote_code=True
+)
+
+def generate_story(prompt, max_length=400, temperature=0.7, top_p=0.9):
+    formatted_prompt = f"### Instruction:\nCreate a story for young children about: {prompt}\n\n### Response:\n"
+    inputs = tokenizer.encode(formatted_prompt, return_tensors="pt").to(model.device)
+    outputs = model.generate(
+        inputs,
+        max_length=max_length,
+        temperature=temperature,
+        top_p=top_p,
+        do_sample=True,
+        pad_token_id=tokenizer.eos_token_id,
+        repetition_penalty=1.1
+    )
+    return tokenizer.decode(outputs[0], skip_special_tokens=True).split("### Response:")[-1].strip()
+
+gr.Interface(fn=generate_story, inputs="text", outputs="text").launch()
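
For reference, the updated generate_story wraps the user's topic in the Alpaca-style instruction template this fine-tune expects, so a call like generate_story("a shy dragon") generates from the prompt:

### Instruction:
Create a story for young children about: a shy dragon

### Response:

Note that load_in_4bit=True relies on bitsandbytes, which generally requires a CUDA GPU, so this revision needs GPU hardware rather than the CPU tier the previous falcon-rw-1b setup targeted. Once the Space rebuilds, the single-textbox interface can also be driven programmatically. A minimal sketch using gradio_client (the Space id below is a placeholder, not taken from this commit):

from gradio_client import Client

# Placeholder Space id -- substitute the actual owner/space name.
client = Client("aYeShaSiddiqA/genieverse-lite")

# A gr.Interface wrapping a single function is exposed at api_name="/predict";
# only the first parameter (the topic text) is wired to the textbox input.
story = client.predict("a shy dragon who learns to sing", api_name="/predict")
print(story)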