-
Notifications
You must be signed in to change notification settings - Fork 47
/
ml.py
44 lines (33 loc) · 1.03 KB
/
ml.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from pathlib import Path
import torch
from diffusers import StableDiffusionPipeline
from PIL.Image import Image
# Hugging Face access token, read from "token.txt" (a relative path, so it is
# resolved against the current working directory).
# Get your token at https://huggingface.co/settings/tokens
token_path = Path("token.txt")
try:
    token = token_path.read_text(encoding="utf-8").strip()
except FileNotFoundError as err:
    # Same exception type as before, but with an actionable message instead of
    # a bare "No such file or directory".
    raise FileNotFoundError(
        f"Hugging Face token file not found: {token_path.resolve()}. "
        "Create it and paste in your token from "
        "https://huggingface.co/settings/tokens"
    ) from err

# Load Stable Diffusion v1.4 from its "fp16" revision with float16 weights,
# authenticating the download with the token read above.
# NOTE(review): newer diffusers releases deprecate `revision="fp16"` and
# `use_auth_token` in favor of `variant="fp16"` and `token`; kept as-is for
# compatibility with the diffusers version this script was written against —
# confirm before upgrading.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=token,
)
# Move the pipeline's models to the GPU.
pipe.to("cuda")
# Example:
# prompt = "a photograph of an astronaut riding a horse"
# image = pipe(prompt).images[0]
def obtain_image(
    prompt: str,
    *,
    seed: int | None = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
) -> Image:
    """Generate one image for *prompt* using the module-level ``pipe``.

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: When given, a ``torch.Generator`` seeded with this value is
            passed to the pipeline so repeated calls reproduce the same
            image; ``None`` (the default) passes no generator.
        num_inference_steps: Number of inference steps forwarded to the
            pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    # Create the generator on the pipeline's own device rather than a
    # hard-coded "cuda", so seeding keeps working if ``pipe`` is moved to a
    # different device (e.g. "cuda:1" or "cpu").
    generator = (
        None
        if seed is None
        else torch.Generator(device=pipe.device).manual_seed(seed)
    )
    print(f"Using device: {pipe.device}")
    image: Image = pipe(
        prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images[0]
    return image
# image = obtain_image(prompt, num_inference_steps=5, seed=1024)