from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import gradio as gr

# Model configuration: half-precision weights, device placement delegated
# to accelerate via device_map="auto".
model_id = "witfoo/witq-1.0"
dtype = torch.float16
device = "auto"

# Load the tokenizer and causal-LM weights once at import time.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=dtype,
    device_map=device,
)

# System-role preamble prepended to every instruction before chat templating.
preamble = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."
def input_tokens(instruction, prompt):
    """Build model-ready input token ids for one chat turn.

    The system message combines the global ``preamble`` with the selected
    instruction; the user message carries the raw prompt text. Returns a
    tensor of token ids already moved to the model's device.
    """
    chat = [
        {"role": "system", "content": preamble + " " + instruction},
        {"role": "user", "content": prompt},
    ]
    token_ids = tokenizer.apply_chat_template(
        chat,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    return token_ids.to(model.device)
def generate_response(instruction, input_text):
    """Generate a sampled completion for (instruction, input_text).

    Tokenizes via input_tokens(), samples up to 256 new tokens, and returns
    the decoded text of only the newly generated portion.
    """
    input_ids = input_tokens(instruction, input_text)
    # Stop on EOS; also stop on "<|eot_id|>" when the tokenizer defines it.
    # convert_tokens_to_ids can return None for an unknown token, and a None
    # entry in eos_token_id breaks generate() — filter it out.
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    terminators = [t for t in (tokenizer.eos_token_id, eot_id) if t is not None]
    outputs = model.generate(
        input_ids,
        max_new_tokens=256,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
        # Explicit pad id avoids the per-call "pad_token_id not set" warning.
        pad_token_id=tokenizer.eos_token_id,
    )
    # Slice off the prompt tokens; decode only the generated tail.
    generated = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(generated, skip_special_tokens=True)
def chatbot(instructions, input_text):
    """Gradio callback: free cached GPU memory, then run generation."""
    # empty_cache() raises on CPU-only / CUDA-less builds — guard it so the
    # app also runs without a GPU.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return generate_response(instructions, input_text)
# Instructions the model was fine-tuned on; offered as a dropdown.
trained_instructions = [
    "Answer this question",
    "Create a JSON artifact from the message",
    "Identify this syslog message",
    "Explain this syslog message",
]

# Simple two-input / one-output interface around the chatbot callback.
iface = gr.Interface(
    fn=chatbot,
    inputs=[
        gr.Dropdown(choices=trained_instructions, label="Instruction"),
        gr.Textbox(lines=2, placeholder="Enter your input here...", label="Input Text"),
    ],
    outputs=gr.Textbox(label="Response"),
    title="WitQ Chatbot",
)

# Render the interface inside a Blocks container and start the server.
app = gr.Blocks()

with app:
    iface.render()

app.launch()