🎨 Build an AI Image Generator

Create a web app for generating images with Stable Diffusion

Project Overview

Build a full-stack web application that generates images from text using Stable Diffusion. Users enter a prompt and see the AI-generated artwork in the browser as soon as each generation finishes.

What You'll Build:

  • Web interface for text-to-image generation
  • Backend API with Stable Diffusion
  • Image gallery to save generations
  • Advanced controls (guidance, steps, size)

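🛠️ Tech Stack

  • Backend: Python, FastAPI, Uvicorn
  • Model: Stable Diffusion via diffusers (PyTorch)
  • Frontend: HTML, CSS, JavaScript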

📁 Project Structure

image-generator/
├── backend/
│   ├── main.py              # FastAPI app
│   ├── generator.py         # SD pipeline
│   └── requirements.txt
├── frontend/
│   ├── index.html
│   ├── style.css
│   └── script.js
└── generated_images/        # Output folder

🔧 Backend Setup

requirements.txt

fastapi
uvicorn
diffusers
transformers
accelerate
torch
pillow
python-multipart

generator.py

from diffusers import StableDiffusionPipeline
import torch
from PIL import Image
import io

class ImageGenerator:
    """Text-to-image generator wrapping a Stable Diffusion pipeline.

    The pipeline is loaded once in the constructor and reused by every
    call to :meth:`generate`.
    """

    def __init__(
        self,
        model_id: str = "stabilityai/stable-diffusion-2-1",
        device: "str | None" = None,
    ):
        """Load the pipeline and move it to the target device.

        Args:
            model_id: Identifier passed to
                ``StableDiffusionPipeline.from_pretrained``.
            device: Torch device string. When omitted, CUDA is used if it
                is available and the CPU otherwise.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        # Half precision keeps GPU memory use down; off-GPU we load in
        # float32 because float16 inference is generally not usable on CPU.
        dtype = torch.float16 if device.startswith("cuda") else torch.float32

        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype
        )
        self.pipe = self.pipe.to(device)
        self.pipe.enable_attention_slicing()

    def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_steps: int = 50,
        guidance_scale: float = 7.5,
        width: int = 512,
        height: int = 512,
        seed: "int | None" = None
    ) -> Image.Image:
        """Run the pipeline once and return the first generated image.

        Args:
            prompt: Text describing the desired image.
            negative_prompt: Text forwarded as the pipeline's
                ``negative_prompt``.
            num_steps: Forwarded as ``num_inference_steps``.
            guidance_scale: Forwarded as ``guidance_scale``.
            width: Requested image width, forwarded unchanged.
            height: Requested image height, forwarded unchanged.
            seed: When given (including 0), seeds a dedicated torch
                generator for this call; when ``None`` no generator is
                passed to the pipeline.

        Returns:
            The first image of the pipeline output.
        """
        # Named `rng` so it does not shadow the module-level `generator`
        # singleton defined below this class.
        rng = None
        if seed is not None:
            # The generator is created on the same device as the pipeline.
            rng = torch.Generator(self.device).manual_seed(seed)

        image = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            width=width,
            height=height,
            generator=rng
        ).images[0]

        return image

# Initialize generator (singleton)
# Instantiated at import time, so the model is loaded once when this module
# is first imported; main.py shares this instance via
# `from generator import generator`.
generator = ImageGenerator()

main.py (FastAPI)

import io
import uuid
from pathlib import Path
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from generator import generator

app = FastAPI()

# Enable CORS so a frontend served from a different origin (for example a
# static file server on another port) can call this API from the browser.
# NOTE: the wildcard origin is convenient for local development; restrict
# allow_origins to the real frontend origin before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    # Custom response headers are hidden from cross-origin scripts unless
    # they are listed here; /generate reports the saved name in X-Filename.
    expose_headers=["X-Filename"],
)

class GenerationRequest(BaseModel):
    """JSON body accepted by POST /generate.

    Only ``prompt`` is required; every other field falls back to the
    default declared here. Invalid values are rejected during request
    validation instead of failing later inside the pipeline.
    """

    prompt: str
    negative_prompt: str = "blurry, low quality, deformed"
    # At least one denoising step is required.
    # NOTE: no upper bound is enforced here; cap it if the API is exposed
    # to untrusted clients, since cost grows with the step count.
    num_steps: int = Field(default=50, ge=1)
    guidance_scale: float = 7.5
    # Stable Diffusion pipelines require dimensions divisible by 8.
    width: int = Field(default=512, ge=8, multiple_of=8)
    height: int = Field(default=512, ge=8, multiple_of=8)
    # Optional so that an explicit JSON null is accepted as "no fixed seed".
    seed: Optional[int] = None

@app.post("/generate")
async def generate_image(request: GenerationRequest):
    """Generate an image, store it on disk, and return it as a PNG response.

    The image is written to ``generated_images/`` under a random UUID file
    name, and that name is reported in the ``X-Filename`` response header.

    Args:
        request: Validated generation parameters from the JSON body.

    Returns:
        A streaming ``image/png`` response containing the generated image.

    Raises:
        HTTPException: 500 carrying the error text if generation or saving
            fails.
    """
    try:
        # NOTE(review): this is a synchronous call inside an async handler,
        # so the event loop serves no other request until it returns.
        image = generator.generate(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            num_steps=request.num_steps,
            guidance_scale=request.guidance_scale,
            width=request.width,
            height=request.height,
            seed=request.seed
        )

        # Encode to PNG once and reuse the bytes for both the saved file
        # and the response body.
        png_buffer = io.BytesIO()
        image.save(png_buffer, format='PNG')

        # Save image
        output_dir = Path("generated_images")
        output_dir.mkdir(parents=True, exist_ok=True)

        filename = f"{uuid.uuid4()}.png"
        (output_dir / filename).write_bytes(png_buffer.getvalue())

        # Return image as stream
        png_buffer.seek(0)
        return StreamingResponse(
            png_buffer,
            media_type="image/png",
            headers={"X-Filename": filename}
        )

    except Exception as e:
        # Request boundary: turn any failure into a 500, chaining the cause
        # so the original traceback is preserved.
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.get("/gallery")
async def get_gallery():
    """Return the names of the 20 most recently written images.

    Returns:
        ``{"images": [...]}`` with PNG file names from ``generated_images/``,
        ordered oldest to newest. The list is empty when nothing has been
        generated yet.
    """
    output_dir = Path("generated_images")
    # glob() yields entries in arbitrary order and the file names are random
    # UUIDs, so sort by modification time to make "last 20" mean "newest 20".
    images = sorted(
        output_dir.glob("*.png"),
        key=lambda img: img.stat().st_mtime,
    )
    return {"images": [img.name for img in images[-20:]]}  # Last 20

@app.get("/image/{filename}")
async def get_image(filename: str):
    """Serve one previously generated image by file name.

    Args:
        filename: Bare file name inside ``generated_images/``.

    Raises:
        HTTPException: 404 if the name contains path components or does not
            refer to an existing regular file.
    """
    # The name comes from the client: accept only a bare file name so the
    # lookup cannot leave the output folder (e.g. via "..\\" on Windows).
    if Path(filename).name != filename or "\\" in filename:
        raise HTTPException(status_code=404, detail="Image not found")

    filepath = Path("generated_images") / filename
    # is_file() rather than exists(): a directory such as ".." exists but
    # cannot be served as a file.
    if not filepath.is_file():
        raise HTTPException(status_code=404, detail="Image not found")
    return FileResponse(filepath)

if __name__ == "__main__":
    # Allows starting the API with `python main.py`.
    import uvicorn
    # NOTE: host 0.0.0.0 listens on all network interfaces, not only
    # localhost; the server is reachable on port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)

🎨 Frontend (HTML + JavaScript)

index.html

<!DOCTYPE html>
<html lang="en">
<head>
    <!-- Declare the encoding explicitly: the page contains emoji, which
         render as garbage if the browser guesses a legacy charset. -->
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>AI Image Generator</title>
    <link rel="stylesheet" href="style.css">
</head>
<body>
    <div class="container">
        <h1>🎨 AI Image Generator</h1>
        
        <!-- Element ids below are looked up by script.js; keep them in sync. -->
        <div class="controls">
            <textarea id="prompt" placeholder="Describe your image..." aria-label="Prompt"></textarea>
            <textarea id="negative-prompt" placeholder="What to avoid..." aria-label="Negative prompt"></textarea>
            
            <div class="sliders">
                <label for="steps">Steps: <span id="steps-value">50</span></label>
                <input type="range" id="steps" min="20" max="100" value="50">
                
                <label for="guidance">Guidance: <span id="guidance-value">7.5</span></label>
                <input type="range" id="guidance" min="1" max="20" step="0.5" value="7.5">
            </div>
            
            <button type="button" id="generate-btn">Generate Image</button>
        </div>
        
        <!-- Shown by script.js while a generation request is in flight. -->
        <div id="loading" role="status" style="display:none;">Generating...</div>
        <div id="result"></div>
        
        <h2>Gallery</h2>
        <div id="gallery"></div>
    </div>
    
    <script src="script.js"></script>
</body>
</html>

script.js

// Base URL of the backend API.
const API_URL = "http://localhost:8000";

// Start a generation when the button is clicked.
document.getElementById("generate-btn").addEventListener("click", generateImage);

// Update slider values: each range input "<id>" mirrors its current value
// into the matching "<id>-value" readout while it is being dragged.
for (const sliderId of ["steps", "guidance"]) {
    document.getElementById(sliderId).addEventListener("input", (event) => {
        const readout = document.getElementById(`${sliderId}-value`);
        readout.textContent = event.target.value;
    });
}

/**
 * Send the current form values to POST /generate and display the result.
 *
 * Shows the loading indicator and disables the button while the request is
 * in flight, then refreshes the gallery on success. Failures are reported
 * to the user with alert().
 */
async function generateImage() {
    // Trim so a whitespace-only prompt is treated as empty.
    const prompt = document.getElementById("prompt").value.trim();
    if (!prompt) {
        alert("Please enter a prompt!");
        return;
    }
    
    const loading = document.getElementById("loading");
    const result = document.getElementById("result");
    const btn = document.getElementById("generate-btn");
    
    // Release the blob URL of the image being replaced; otherwise every
    // generated image stays in memory until the page is closed.
    const previousImage = result.querySelector("img");
    if (previousImage && previousImage.src.startsWith("blob:")) {
        URL.revokeObjectURL(previousImage.src);
    }
    
    // Show loading
    loading.style.display = "block";
    result.innerHTML = "";
    btn.disabled = true;
    
    try {
        const response = await fetch(`${API_URL}/generate`, {
            method: "POST",
            headers: {"Content-Type": "application/json"},
            body: JSON.stringify({
                prompt: prompt,
                negative_prompt: document.getElementById("negative-prompt").value,
                num_steps: parseInt(document.getElementById("steps").value, 10),
                guidance_scale: parseFloat(document.getElementById("guidance").value),
                width: 512,
                height: 512
            })
        });
        
        if (!response.ok) {
            // The API reports failures as JSON {"detail": ...}; include the
            // detail when it is a plain string, otherwise stay generic.
            let message = "Generation failed";
            try {
                const errorBody = await response.json();
                if (errorBody && typeof errorBody.detail === "string") {
                    message = `Generation failed: ${errorBody.detail}`;
                }
            } catch (parseError) {
                // Error body was not JSON: keep the generic message.
            }
            throw new Error(message);
        }
        
        // Get image
        const blob = await response.blob();
        const imageUrl = URL.createObjectURL(blob);
        
        // Display
        const img = document.createElement("img");
        img.src = imageUrl;
        img.alt = "Generated";
        result.replaceChildren(img);
        
        // Refresh gallery
        loadGallery();
        
    } catch (error) {
        alert("Error: " + error.message);
    } finally {
        // Always restore the controls, whether the request succeeded or not.
        loading.style.display = "none";
        btn.disabled = false;
    }
}

/**
 * Fetch the list of saved images from GET /gallery and render them as
 * thumbnails inside #gallery. Failures are logged to the console and leave
 * the current gallery contents untouched.
 */
async function loadGallery() {
    try {
        const response = await fetch(`${API_URL}/gallery`);
        if (!response.ok) {
            throw new Error(`Gallery request failed with status ${response.status}`);
        }
        const data = await response.json();
        
        const gallery = document.getElementById("gallery");
        // Build elements instead of an HTML string so file names returned
        // by the server are never interpreted as markup.
        const thumbnails = data.images.map((name) => {
            const img = document.createElement("img");
            img.src = `${API_URL}/image/${encodeURIComponent(name)}`;
            img.className = "gallery-img";
            img.alt = "Generated image";
            return img;
        });
        gallery.replaceChildren(...thumbnails);
    } catch (error) {
        console.error("Failed to load gallery:", error);
    }
}

// Load gallery on page load
loadGallery();

🚀 Running the App

# Install dependencies
cd backend
pip install -r requirements.txt

# Run backend (keeps running in this terminal and serves the API on port 8000)
python main.py

# Open frontend
# Open frontend/index.html in browser
# Or use a simple HTTP server (in a second terminal, from the project root):
cd frontend
python -m http.server 8080

🎯 Key Takeaways