import os
import shutil
import zipfile

import gradio as gr
import torch
from diffusers import DDPMScheduler, StableDiffusionPipeline
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import CLIPTokenizer

# Compute device shared by training and inference: CUDA when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class CustomImageDataset(Dataset):
    """In-memory dataset that pairs each image with its text prompt.

    Args:
        images: Indexable collection of images; its length is the dataset size.
        prompts: Indexable collection of prompts; ``prompts[i]`` is returned
            together with ``images[i]``.
        transform: Optional callable applied to an image every time it is
            fetched. ``None`` returns the stored image unchanged.
    """

    def __init__(self, images, prompts, transform=None):
        self.images = images
        self.prompts = prompts
        self.transform = transform

    def __len__(self):
        """Return the number of images held by the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, prompt)`` pair stored at position ``idx``."""
        image = self.images[idx]
        # Compare against None explicitly so a callable that happens to be
        # falsy (e.g. an empty container of transforms) is still applied.
        if self.transform is not None:
            image = self.transform(image)
        return image, self.prompts[idx]


def fine_tune_model(images, prompts, model_save_path, num_epochs=3):
    """Fine-tune the UNet of Stable Diffusion 2 on image/prompt pairs and save it.

    Only the UNet is optimised; the VAE and the text encoder are frozen and
    used purely to produce latents and text embeddings.

    Args:
        images: Sequence of PIL images used as training targets.
        prompts: Sequence of prompt strings, one per image.
        model_save_path: Directory the complete pipeline is written to.
        num_epochs: Number of passes over the dataset.

    Raises:
        ValueError: If the scheduler's ``prediction_type`` is neither
            ``"epsilon"`` nor ``"v_prediction"``.
    """
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        # Map pixel values from [0, 1] to [-1, 1].
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    dataset = CustomImageDataset(images, prompts, transform)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2").to(device)

    vae = pipeline.vae
    unet = pipeline.unet
    text_encoder = pipeline.text_encoder
    # Use the tokenizer shipped with the checkpoint so that token ids and
    # padding match what the pipeline's text encoder expects.
    tokenizer = pipeline.tokenizer
    # Training-time scheduler built from the checkpoint's own noise schedule.
    noise_scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)
    num_train_timesteps = noise_scheduler.config.num_train_timesteps
    prediction_type = noise_scheduler.config.prediction_type

    # Freeze everything except the UNet.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.train()

    optimizer = torch.optim.AdamW(unet.parameters(), lr=5e-6)

    for _ in range(num_epochs):
        for batch_images, batch_prompts in dataloader:
            batch_images = batch_images.to(device)

            tokens = tokenizer(
                list(batch_prompts),
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            input_ids = tokens.input_ids.to(device)

            # The frozen encoders do not need autograd graphs.
            with torch.no_grad():
                latents = vae.encode(batch_images).latent_dist.sample()
                latents = latents * vae.config.scaling_factor
                text_embeddings = text_encoder(input_ids)[0]

            noise = torch.randn_like(latents)
            # One random diffusion step per sample, drawn from the full schedule.
            timesteps = torch.randint(
                0, num_train_timesteps, (latents.shape[0],), device=latents.device
            )
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # The regression target depends on what the checkpoint predicts.
            if prediction_type == "epsilon":
                target = noise
            elif prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unsupported prediction type: {prediction_type}")

            model_pred = unet(
                noisy_latents, timesteps, encoder_hidden_states=text_embeddings
            ).sample

            loss = torch.nn.functional.mse_loss(model_pred.float(), target.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    unet.eval()
    pipeline.save_pretrained(model_save_path)


def tensor_to_pil(tensor):
    """Convert an image tensor to a PIL image.

    Values are clamped to ``[0, 1]`` before conversion.

    Args:
        tensor: Image tensor. A 4-D input is treated as a batch and its
            leading axis is squeezed (so it must hold a single image);
            lower-rank inputs are passed to ``ToPILImage`` as they are.

    Returns:
        The image produced by ``torchvision.transforms.ToPILImage``.
    """
    image_tensor = tensor.detach().cpu()
    # Drop only the batch axis: a bare ``squeeze()`` would also remove a
    # size-1 height, width or channel axis and corrupt the image layout.
    if image_tensor.dim() == 4:
        image_tensor = image_tensor.squeeze(0)
    image_tensor = image_tensor.clamp(0, 1)
    return transforms.ToPILImage()(image_tensor)


def generate_images(pipeline, prompt):
    """Run ``pipeline`` on ``prompt`` and return the first generated image.

    Args:
        pipeline: Callable that accepts the prompt and returns an object with
            an ``images`` sequence.
        prompt: Text prompt forwarded to the pipeline unchanged.

    Returns:
        The first element of the pipeline output's ``images``.
    """
    # Inference only: disable gradient tracking for the duration of the call.
    with torch.no_grad():
        output = pipeline(prompt)
    return output.images[0]


def zip_model(model_path):
    """Archive every file under ``model_path`` into ``<model_path>.zip``.

    Archive member names are relative to ``model_path``.

    Args:
        model_path: Directory to archive.

    Returns:
        Path of the zip file that was written.

    Raises:
        NotADirectoryError: If ``model_path`` is not an existing directory.
    """
    # Normalise first: with a trailing separator the archive would otherwise
    # be created as "<dir>/.zip", i.e. inside the tree being archived.
    model_dir = os.path.normpath(model_path)
    if not os.path.isdir(model_dir):
        raise NotADirectoryError(f"Model directory not found: {model_path}")

    zip_path = f"{model_dir}.zip"
    with zipfile.ZipFile(zip_path, "w") as zipf:
        for root, _, files in os.walk(model_dir):
            for file_name in files:
                full_path = os.path.join(root, file_name)
                zipf.write(full_path, os.path.relpath(full_path, model_dir))
    return zip_path


def save_uploaded_file(uploaded_file, save_path):
    """Persist an uploaded file to ``save_path``.

    Args:
        uploaded_file: One of: raw ``bytes``/``bytearray``; an object whose
            ``data`` attribute holds the bytes; a filesystem path; or an
            object whose ``name`` attribute is a filesystem path.
        save_path: Destination file path.

    Returns:
        A confirmation message containing ``save_path``.

    Raises:
        TypeError: If ``uploaded_file`` matches none of the supported forms.
    """
    if isinstance(uploaded_file, (bytes, bytearray)):
        payload = uploaded_file
    else:
        payload = getattr(uploaded_file, "data", None)

    if payload is not None:
        with open(save_path, "wb") as f:
            f.write(payload)
    else:
        # No in-memory payload: fall back to copying from a path on disk.
        if isinstance(uploaded_file, (str, os.PathLike)):
            source = uploaded_file
        else:
            source = getattr(uploaded_file, "name", None)
        if source is None:
            raise TypeError(
                f"Unsupported upload object: {type(uploaded_file).__name__}"
            )
        shutil.copyfile(source, save_path)
    return f"File saved at {save_path}"


def start_fine_tuning(uploaded_files, prompts, num_epochs):
    """Validate the UI inputs, run fine-tuning and report the outcome.

    Args:
        uploaded_files: Uploaded image files; each item is either a path or an
            object whose ``name`` attribute is a path.
        prompts: Comma-separated prompt string (or an already-split sequence
            of prompts). Either one prompt per image, or a single prompt that
            is reused for every image. Because commas separate prompts, an
            individual prompt cannot itself contain a comma.
        num_epochs: Number of training epochs; converted with ``int()``.

    Returns:
        A status message: a description of the input problem when validation
        fails, otherwise the completion message.
    """
    if not uploaded_files:
        return "Please upload at least one image."

    if isinstance(prompts, str):
        prompt_list = [part.strip() for part in prompts.split(",") if part.strip()]
    else:
        prompt_list = list(prompts or [])

    image_count = len(uploaded_files)
    if len(prompt_list) == 1:
        # A single prompt describes every uploaded image.
        prompt_list = prompt_list * image_count
    if len(prompt_list) != image_count:
        return (
            f"Expected {image_count} prompts (one per image) or a single "
            f"shared prompt, but got {len(prompt_list)}."
        )

    try:
        epochs = int(num_epochs)
    except (TypeError, ValueError):
        return "Number of epochs must be a whole number."
    if epochs < 1:
        return "Number of epochs must be at least 1."

    images = []
    for file in uploaded_files:
        source = getattr(file, "name", file)
        # ``convert`` returns an independent image, so the file handle opened
        # by ``Image.open`` can be released straight away.
        with Image.open(source) as opened:
            images.append(opened.convert("RGB"))

    model_save_path = "fine_tuned_model"
    fine_tune_model(images, prompt_list, model_save_path, num_epochs=epochs)
    return "Fine-tuning completed! Model is ready for download."


def download_model():
    """Zip the fine-tuned model directory for download.

    Returns:
        Path of the created zip archive, or ``None`` when the
        ``fine_tuned_model`` directory does not exist.
    """
    model_save_path = "fine_tuned_model"
    # Require a directory: a stray file with this name is not a saved model.
    if not os.path.isdir(model_save_path):
        return None
    return zip_model(model_save_path)


def generate_new_image(prompt):
    """Generate an image for ``prompt`` and return the path it was saved to.

    The fine-tuned pipeline in ``fine_tuned_model`` is used when that
    directory exists; otherwise the base ``stabilityai/stable-diffusion-2``
    checkpoint is loaded.

    Args:
        prompt: Text prompt to render.

    Returns:
        Path of the saved PNG (``generated_image.png``; overwritten on every
        call).
    """
    model_save_path = "fine_tuned_model"
    if os.path.isdir(model_save_path):
        model_source = model_save_path
    else:
        model_source = "stabilityai/stable-diffusion-2"
    pipeline = StableDiffusionPipeline.from_pretrained(model_source).to(device)

    image = generate_images(pipeline, prompt)
    image_path = "generated_image.png"
    image.save(image_path)
    return image_path


# Gradio UI with three tabs: fine-tune the model, download it, generate images.
with gr.Blocks() as demo:
    gr.Markdown("# Fine-Tune Stable Diffusion and Generate Images")

    with gr.Tab("Fine-Tune Model"):
        with gr.Row():
            uploaded_files = gr.File(
                label="Upload Images",
                file_types=[".png", ".jpg", ".jpeg"],
                file_count="multiple",
            )
        with gr.Row():
            prompts = gr.Textbox(label="Enter Prompts (comma-separated)")
            num_epochs = gr.Number(label="Number of Epochs", value=3)
        with gr.Row():
            fine_tune_button = gr.Button("Start Fine-Tuning")
            fine_tune_output = gr.Textbox(label="Output")

        fine_tune_button.click(
            start_fine_tuning,
            [uploaded_files, prompts, num_epochs],
            fine_tune_output,
        )

    with gr.Tab("Download Fine-Tuned Model"):
        download_button = gr.Button("Download Fine-Tuned Model")
        download_output = gr.File()

        download_button.click(download_model, [], download_output)

    with gr.Tab("Generate New Images"):
        prompt_input = gr.Textbox(label="Enter a Prompt")
        generate_button = gr.Button("Generate Image")
        generated_image = gr.Image(label="Generated Image")

        generate_button.click(generate_new_image, [prompt_input], generated_image)

# Start the web server as soon as the script is executed.
demo.launch()